# ASR-Based Subtitle Generator for Legacy Internet Clips

This notebook implements a subtitle generation system built on state-of-the-art open-source ASR models. It targets "old clips from the internet", whose audio is often challenging (background noise, low sample rates, poor recording quality).

## Models Supported:
- **OLMoASR** (primary recommendation) – openly released models with competitive accuracy
- **Wav2Vec 2.0** (for fine-tuning) – best suited to domain-specific adaptation
- **Whisper** (fallback) – robust multilingual model (this notebook requests English transcription)

---

## 1. Setup and Installation

First, let's install the required dependencies and set up our environment.

In [None]:
# Install required packages (run this cell first)
!pip install torch torchaudio transformers datasets accelerate librosa soundfile ipython

# For Google Colab users, uncomment the next line:
# !pip install torch torchaudio transformers datasets accelerate librosa soundfile --quiet

In [None]:
# Import required libraries
import os
import torch
import torchaudio
import json
import time
import warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from IPython.display import display, HTML, Audio
import matplotlib.pyplot as plt
import numpy as np

from transformers import (
    AutoProcessor, 
    AutoModelForSpeechSeq2Seq,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    pipeline
)

# Silence every warning category to keep notebook output readable
# (comment this out while debugging).
warnings.filterwarnings("ignore")

# Report the environment up front so a missing GPU is obvious.
cuda_available = torch.cuda.is_available()
for message in (
    "✓ Libraries imported successfully",
    f"✓ PyTorch version: {torch.__version__}",
    f"✓ CUDA available: {cuda_available}",
):
    print(message)
if cuda_available:
    print(f"✓ GPU: {torch.cuda.get_device_name()}")

## 2. Core Classes and Functions

Let's define our subtitle segment data structure and the main ASR generator class.

In [None]:
@dataclass
class SubtitleSegment:
    """A single subtitle cue: a time span in seconds and the text shown during it.

    ``confidence`` is an optional score; it stays ``None`` unless the code
    creating the segment supplies one.
    """
    start_time: float  # cue start, in seconds
    end_time: float    # cue end, in seconds
    text: str
    confidence: Optional[float] = None

    def duration(self) -> float:
        """Return the cue length in seconds (end minus start)."""
        span = self.end_time - self.start_time
        return span

    def __str__(self):
        return "[{:.1f}s - {:.1f}s]: {}".format(self.start_time, self.end_time, self.text)

In [None]:
class ASRSubtitleGenerator:
    """
    ASR-based subtitle generator optimized for legacy internet clips.

    Supported model families (keys of ``SUPPORTED_MODELS``):
    - OLMoASR (primary recommendation): AutoProcessor + AutoModelForSpeechSeq2Seq,
      run through a transformers "automatic-speech-recognition" pipeline.
    - Wav2Vec 2.0 (for fine-tuning scenarios): CTC model, greedily decoded over
      fixed-length, non-overlapping windows.
    - Whisper (fallback option): run through the same kind of pipeline as OLMoASR.

    Typical use::

        generator = ASRSubtitleGenerator("whisper-tiny")
        segments = generator.transcribe_audio("clip.wav")
        generator.export_srt(segments, "clip.srt")
    """

    SUPPORTED_MODELS = {
        'olmoasr-large': 'allenai/OLMoASR-large.en-v2',
        'olmoasr-base': 'allenai/OLMoASR-base.en-v2',
        'olmoasr-tiny': 'allenai/OLMoASR-tiny.en-v2',
        'wav2vec2-base': 'facebook/wav2vec2-base-960h',
        'wav2vec2-large': 'facebook/wav2vec2-large-960h-lv60-self',
        'whisper-tiny': 'openai/whisper-tiny',
        'whisper-base': 'openai/whisper-base',
        'whisper-small': 'openai/whisper-small',
        'whisper-medium': 'openai/whisper-medium',
        'whisper-large': 'openai/whisper-large-v3'
    }

    # Sample rate (Hz) used for the CTC path and for measuring clip length;
    # matches the default ``target_sr`` of ``load_audio``.
    SAMPLE_RATE = 16000

    # NOTE(review): wav2vec2's convolutional feature encoder needs roughly 400
    # samples (25 ms at 16 kHz) to emit a single frame, so shorter trailing
    # windows are skipped instead of aborting the whole transcription --
    # confirm for custom checkpoints.
    _MIN_CTC_SAMPLES = 400

    def __init__(self, model_name: str = 'olmoasr-large', device: str = 'auto'):
        """
        Initialize the generator and eagerly load the model.

        Args:
            model_name: Key of ``SUPPORTED_MODELS``.
            device: Torch device string ('cpu', 'cuda', 'cuda:0', ...) or
                'auto' to pick 'cuda' when available, otherwise 'cpu'.

        Raises:
            ValueError: If ``model_name`` is not a key of ``SUPPORTED_MODELS``.
        """
        self.model_name = model_name
        self.model_id = self.SUPPORTED_MODELS.get(model_name)

        if not self.model_id:
            raise ValueError(f"Unsupported model: {model_name}. "
                             f"Supported models: {list(self.SUPPORTED_MODELS.keys())}")

        # Auto-detect device
        if device == 'auto':
            if torch.cuda.is_available():
                self.device = 'cuda'
                print(f"✓ Using GPU: {torch.cuda.get_device_name()}")
            else:
                self.device = 'cpu'
                print("✓ Using CPU (consider GPU for faster processing)")
        else:
            self.device = device

        self.processor = None
        self.model = None
        self.pipe = None  # only set for pipeline-backed models (Whisper / OLMoASR)

        self._load_model()

    def _load_model(self):
        """
        Load the processor and model for ``self.model_id`` onto ``self.device``.

        Whisper and OLMoASR models also get a transformers ASR pipeline in
        ``self.pipe``; Wav2Vec2 models are run manually in ``transcribe_audio``.

        Raises:
            Exception: Whatever the transformers loaders raise, re-raised after
                printing a short message.
        """
        print(f"Loading {self.model_name} model...")

        try:
            if 'whisper' in self.model_name:
                self.processor = WhisperProcessor.from_pretrained(self.model_id)
                self.model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
                self.model.to(self.device)
                self.pipe = self._build_pipeline()

            elif 'wav2vec2' in self.model_name:
                self.processor = Wav2Vec2Processor.from_pretrained(self.model_id)
                self.model = Wav2Vec2ForCTC.from_pretrained(self.model_id)
                self.model.to(self.device)

            elif 'olmoasr' in self.model_name:
                self.processor = AutoProcessor.from_pretrained(self.model_id)
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(self.model_id)
                self.model.to(self.device)
                self.pipe = self._build_pipeline()

            print(f"✓ Successfully loaded {self.model_name}")

        except Exception as e:
            print(f"✗ Error loading model: {e}")
            raise

    def _build_pipeline(self):
        """Wrap the already-loaded seq2seq model in a transformers ASR pipeline."""
        return pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            device=self.device,
            return_timestamps=True,
        )

    def load_audio(self, audio_path: str, target_sr: int = 16000) -> torch.Tensor:
        """
        Load an audio file as a mono waveform at ``target_sr``.

        Multi-channel audio is averaged to one channel and resampled when the
        file's rate differs from ``target_sr``.

        Returns:
            The waveform with the singleton channel dimension squeezed away.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            RuntimeError: If loading or resampling fails.
        """
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        try:
            speech, sr = torchaudio.load(audio_path)

            # Convert multi-channel audio to mono
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)

            if sr != target_sr:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
                speech = resampler(speech)

            return speech.squeeze()

        except Exception as e:
            raise RuntimeError(f"Error loading audio {audio_path}: {e}") from e

    def transcribe_audio(self, audio_path: str, chunk_length_s: int = 30) -> "List[SubtitleSegment]":
        """
        Transcribe an audio file into subtitle segments.

        Args:
            audio_path: Path to the audio file.
            chunk_length_s: Window length in seconds. Forwarded to the pipeline
                as ``chunk_length_s`` for Whisper/OLMoASR; used as the fixed,
                non-overlapping window size for Wav2Vec2.

        Returns:
            Segments in the order produced by the model. On the Wav2Vec2 path,
            windows that decode to empty text are omitted.

        Raises:
            RuntimeError: Wrapping any failure during transcription.
        """
        print(f"Transcribing: {Path(audio_path).name}")
        start_time = time.time()

        try:
            if self.pipe is not None and ('whisper' in self.model_name or 'olmoasr' in self.model_name):
                segments = self._transcribe_with_pipeline(audio_path, chunk_length_s)
            elif 'wav2vec2' in self.model_name:
                segments = self._transcribe_ctc(audio_path, chunk_length_s)
            else:
                # Previously this fell through to an unbound `segments` variable.
                raise RuntimeError(f"No transcription backend available for {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Error during transcription: {e}") from e

        processing_time = time.time() - start_time
        print(f"✓ Transcription completed in {processing_time:.1f}s")
        return segments

    def _transcribe_with_pipeline(self, audio_path: str, chunk_length_s: int) -> "List[SubtitleSegment]":
        """Run the transformers ASR pipeline and convert its output into segments."""
        call_kwargs = {"chunk_length_s": chunk_length_s, "return_timestamps": True}
        # transformers' Whisper generation rejects `task`/`language` for
        # English-only checkpoints, so only pass them for ids without ".en".
        # NOTE(review): confirm the OLMoASR ".en" generation configs are
        # flagged English-only.
        if ".en" not in self.model_id:
            call_kwargs["generate_kwargs"] = {"task": "transcribe", "language": "english"}
        result = self.pipe(audio_path, **call_kwargs)

        audio_duration = None

        def fallback_end() -> float:
            # Real clip length, computed lazily: only needed when the model
            # leaves a timestamp open.
            nonlocal audio_duration
            if audio_duration is None:
                waveform = self.load_audio(audio_path, target_sr=self.SAMPLE_RATE)
                audio_duration = waveform.numel() / self.SAMPLE_RATE
            return audio_duration

        if 'chunks' not in result:
            # Single transcript without per-chunk timing: span the whole clip.
            return [SubtitleSegment(
                start_time=0,
                end_time=fallback_end(),
                text=result['text'].strip()
            )]

        segments = []
        for chunk in result['chunks']:
            start, end = chunk['timestamp']
            if start is None:
                start = 0
            if end is None:
                # The pipeline can report an open-ended timestamp (e.g. on the
                # final chunk). Close it at the end of the clip rather than at
                # chunk_length_s, which may lie before `start`.
                end = max(start, fallback_end())
            segments.append(SubtitleSegment(
                start_time=start,
                end_time=end,
                text=chunk['text'].strip()
            ))
        return segments

    def _transcribe_ctc(self, audio_path: str, chunk_length_s: int) -> "List[SubtitleSegment]":
        """Greedy CTC decoding over fixed-length, non-overlapping windows."""
        sr = self.SAMPLE_RATE
        audio_data = self.load_audio(audio_path, target_sr=sr).reshape(-1)
        total_samples = audio_data.numel()
        window = int(chunk_length_s * sr)

        segments = []
        for offset in range(0, total_samples, window):
            chunk = audio_data[offset:offset + window]
            if chunk.numel() < self._MIN_CTC_SAMPLES:
                continue  # too short to yield any CTC frame

            inputs = self.processor(chunk.numpy(), return_tensors="pt", sampling_rate=sr)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = self.model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = self.processor.batch_decode(predicted_ids)[0].strip()

            if text:
                segments.append(SubtitleSegment(
                    start_time=offset / sr,
                    end_time=(offset + chunk.numel()) / sr,
                    text=text
                ))
        return segments

    def format_time_srt(self, seconds: float) -> str:
        """
        Convert seconds to an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond in integer arithmetic.
        Computing ``(seconds % 1) * 1000`` in floating point truncated e.g.
        2.3 s to ``,299``. Negative input is clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millisecs = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"

    def export_srt(self, segments: "List[SubtitleSegment]", output_path: str):
        """
        Write segments to ``output_path`` in SRT format (UTF-8).

        Segments whose text is blank are skipped. Cue numbers are consecutive
        (1, 2, 3, ...) over the cues actually written, as SRT expects.

        Raises:
            RuntimeError: If the file cannot be written.
        """
        try:
            cues = [segment for segment in segments if segment.text.strip()]
            with open(output_path, 'w', encoding='utf-8') as f:
                for index, segment in enumerate(cues, 1):
                    f.write(f"{index}\n")
                    f.write(f"{self.format_time_srt(segment.start_time)} --> "
                            f"{self.format_time_srt(segment.end_time)}\n")
                    f.write(f"{segment.text.strip()}\n\n")

            print(f"✓ SRT file saved: {output_path}")

        except Exception as e:
            raise RuntimeError(f"Error saving SRT file: {e}") from e

print("✓ ASR Subtitle Generator class defined")

## 3. Model Comparison and Selection

Let's explore the available models and their characteristics.

In [None]:
# Display available models with recommendations.
# 'size' is the published parameter count of each checkpoint, not its download
# size; fp32 weights need roughly 4 bytes per parameter.
model_info = {
    'olmoasr-large': {'size': 'Large (~1.5B params)', 'best_for': 'Legacy clips, transparency', 'speed': 'Medium'},
    'olmoasr-base': {'size': 'Base (~74M params)', 'best_for': 'Balanced performance', 'speed': 'Fast'},
    'olmoasr-tiny': {'size': 'Tiny (~39M params)', 'best_for': 'Resource-constrained', 'speed': 'Very Fast'},
    'wav2vec2-base': {'size': 'Base (~95M params)', 'best_for': 'Clean audio, fine-tuning', 'speed': 'Fast'},
    'wav2vec2-large': {'size': 'Large (~317M params)', 'best_for': 'High accuracy, fine-tuning', 'speed': 'Medium'},
    'whisper-tiny': {'size': 'Tiny (~39M params)', 'best_for': 'Quick testing', 'speed': 'Very Fast'},
    'whisper-base': {'size': 'Base (~74M params)', 'best_for': 'General purpose', 'speed': 'Fast'},
    'whisper-small': {'size': 'Small (~244M params)', 'best_for': 'Good quality/speed trade-off', 'speed': 'Medium'},
    'whisper-medium': {'size': 'Medium (~769M params)', 'best_for': 'High quality transcription', 'speed': 'Medium'},
    'whisper-large': {'size': 'Large (~1.55B params)', 'best_for': 'Maximum accuracy', 'speed': 'Slow'}
}

# Derive column widths from the data so long entries
# (e.g. 'Good quality/speed trade-off') keep the table aligned.
size_width = max(len(info['size']) for info in model_info.values())
best_for_width = max(len(info['best_for']) for info in model_info.values())

print("📊 Available ASR Models:")
print("="*80)
for model, info in model_info.items():
    print(f"🔹 {model:<20} | Size: {info['size']:<{size_width}} | "
          f"Best for: {info['best_for']:<{best_for_width}} | Speed: {info['speed']}")

print("\n🎯 Recommendations for Legacy Internet Clips:")
print("1. Primary: olmoasr-large (Best transparency + performance)")
print("2. Secondary: wav2vec2-large (For fine-tuning scenarios)")
print("3. Fallback: whisper-medium (Robust multilingual option)")

## 4. Audio Analysis and Visualization

Let's create some helper functions to analyze and visualize audio files before processing.

In [None]:
def analyze_audio(audio_path: str):
    """
    Print basic characteristics of an audio file and return its raw data.

    Args:
        audio_path: Path to a file readable by ``torchaudio.load``.

    Returns:
        ``(waveform, sample_rate, duration_seconds)``, where ``waveform`` is the
        channels-first tensor from ``torchaudio.load``. Returns
        ``(None, None, None)`` if the file cannot be loaded. Load errors are
        printed, not raised, so the notebook cell keeps running.
    """
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as e:
        print(f"❌ Error analyzing audio: {e}")
        return None, None, None

    duration = waveform.shape[1] / sample_rate
    num_channels = waveform.shape[0]

    # Round before splitting so e.g. 59.96 s shows as "1:00.0", not "0:60.0".
    minutes, seconds = divmod(round(duration, 1), 60)
    channel_label = {1: 'Mono', 2: 'Stereo'}.get(num_channels, f'{num_channels}-channel')

    print(f"📁 File: {Path(audio_path).name}")
    print(f"⏱️  Duration: {duration:.1f} seconds ({minutes:.0f}:{seconds:04.1f})")
    print(f"🔊 Sample Rate: {sample_rate} Hz")
    print(f"📻 Channels: {num_channels} ({channel_label})")

    # The generator resamples all audio to 16 kHz, so anything above that is
    # downsampled, not just rates above 44.1 kHz.
    if sample_rate < 16000:
        print("⚠️  Low sample rate detected - may affect ASR quality")
    elif sample_rate > 16000:
        print("ℹ️  Sample rate above 16 kHz - will be downsampled for ASR")

    return waveform, sample_rate, duration

def plot_waveform(waveform, sample_rate, max_duration=30):
    """
    Plot up to the first ``max_duration`` seconds of a waveform.

    Args:
        waveform: Channels-first tensor of shape ``(channels, samples)``, as
            returned by ``torchaudio.load``. Multi-channel audio is averaged
            to mono for display. ``None`` is accepted and ignored, so a
            failed ``analyze_audio`` result can be passed straight in.
        sample_rate: Samples per second of ``waveform``.
        max_duration: Maximum number of seconds to draw.
    """
    if waveform is None:
        return

    max_samples = int(max_duration * sample_rate)
    clip = waveform[:, :max_samples]

    # Average channels down to a single trace for plotting
    mono = clip.mean(dim=0) if clip.shape[0] > 1 else clip[0]
    duration_plot = mono.shape[0] / sample_rate

    # Sample k lies at k / sample_rate seconds. linspace(0, duration, n)
    # includes the endpoint and would stretch the axis by n / (n - 1).
    time_axis = torch.arange(mono.shape[0]) / sample_rate

    plt.figure(figsize=(12, 4))
    plt.plot(time_axis, mono)
    plt.title(f'Audio Waveform (first {duration_plot:.1f}s)')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Amplitude')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

print("✓ Audio analysis functions defined")

## 5. Interactive Model Testing

Now let's create an interactive section where you can test different models on your audio files.

In [None]:
# Configuration - Modify these variables for your use case
AUDIO_FILE_PATH = "your_audio_file.wav"  # 👈 Change this to your audio file path
MODEL_TO_TEST = "olmoasr-base"  # 👈 Change this to test different models
CHUNK_LENGTH = 30  # seconds

# Echo the active settings so each run's output records what was used.
print("🎯 Configuration:")
for label, value in (
    ("Audio file", AUDIO_FILE_PATH),
    ("Model", MODEL_TO_TEST),
    ("Chunk length", f"{CHUNK_LENGTH}s"),
):
    print(f"   {label}: {value}")
print("\n💡 Tip: Modify the variables above to test with your own files and models!")

In [None]:
# Step 1: Analyze your audio file (optional but recommended)
if os.path.exists(AUDIO_FILE_PATH):
    print("🔍 Analyzing audio file...")
    waveform, sr, duration = analyze_audio(AUDIO_FILE_PATH)

    if waveform is not None:
        # Plot waveform
        plot_waveform(waveform, sr)

        # Create audio player widget for Jupyter
        display(Audio(AUDIO_FILE_PATH))

else:
    print(f"❌ Audio file not found: {AUDIO_FILE_PATH}")
    print("\n📝 To test with your own audio:")
    print("1. Upload your audio file to the notebook directory")
    print("2. Update AUDIO_FILE_PATH variable above")
    print("3. Re-run this cell")

    # For demonstration, synthesize a short sine tone so later cells have a file.
    print("\n🎵 Creating a sample audio file for demonstration...")
    sample_rate = 16000
    duration = 5  # seconds
    frequency = 440  # Hz (A4 note)

    # Sample k is taken at k / sample_rate seconds. linspace(0, duration, n)
    # would include the endpoint and raise the pitch by a factor n / (n - 1).
    t = torch.arange(int(sample_rate * duration)) / sample_rate
    waveform = torch.sin(2 * torch.pi * frequency * t).unsqueeze(0)  # (1, samples): mono

    AUDIO_FILE_PATH = "sample_tone.wav"
    torchaudio.save(AUDIO_FILE_PATH, waveform, sample_rate)
    print(f"✓ Sample audio created: {AUDIO_FILE_PATH}")
    print("ℹ️  The sample is a pure tone with no speech, so any transcript generated from it is not meaningful.")

In [None]:
# Step 2: Initialize the ASR generator
print(f"🚀 Initializing ASR generator with {MODEL_TO_TEST}...")

try:
    generator = ASRSubtitleGenerator(model_name=MODEL_TO_TEST, device='auto')
except Exception as err:
    # Loading usually fails on memory or download problems; suggest a lighter model.
    print(f"❌ Error initializing generator: {err}")
    print("\n💡 Try using a smaller model like 'whisper-tiny' or 'olmoasr-tiny'")
else:
    print("✅ Generator initialized successfully!")

In [None]:
# Step 3: Generate subtitles
# At notebook top level locals() is the global namespace, so this detects
# whether Step 2 managed to create `generator`.
ready_to_run = 'generator' in locals() and os.path.exists(AUDIO_FILE_PATH)

if not ready_to_run:
    print("\nSkipping subtitle generation. Please ensure the generator is initialized and the audio file exists.")
else:
    print(f"🎬 Generating subtitles for {Path(AUDIO_FILE_PATH).name}...")

    try:
        subtitle_segments = generator.transcribe_audio(AUDIO_FILE_PATH, chunk_length_s=CHUNK_LENGTH)

        print("\n📜 Transcription Results:")
        for seg in subtitle_segments:
            print(seg)

        # SRT is written to the current directory as <audio stem>_<model>.srt
        stem = Path(AUDIO_FILE_PATH).stem
        output_srt_path = stem + f"_{MODEL_TO_TEST}.srt"
        generator.export_srt(subtitle_segments, output_srt_path)

    except Exception as err:
        print(f"❌ An error occurred during subtitle generation: {err}")