<a href="https://colab.research.google.com/github/Yknld/ydl_api_ng/blob/main/Colab_Whisper_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# CELL 0: Setup and Dependencies
# ============================================
# One-time environment setup for this notebook: installs the Python
# dependencies, mounts Google Drive, reports GPU availability, creates the
# Drive folders used for sermon data and models, and prints free disk space.

# Install required packages
# (`!` is an IPython shell escape, so this cell only runs inside a notebook)
print("📦 Installing required packages...")
!pip install -q transformers datasets accelerate librosa soundfile

# Mount Google Drive
# NOTE(review): the recorded output of this cell ends with
# `ValueError: mount failed`, presumably raised by the mount call below;
# nothing after it runs when that happens.
print("📁 Mounting Google Drive...")
from google.colab import drive
drive.mount('/content/drive')

# Check GPU availability
import torch
print(f"🖥️  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name()}")
    # total_memory is divided by 1e9, i.e. reported in decimal gigabytes
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Create necessary directories
import os
sermon_data_dir = "/content/drive/MyDrive/sermon_data"
models_dir = os.path.join(sermon_data_dir, "models")
# makedirs also creates the parent sermon_data folder; exist_ok makes re-runs safe
os.makedirs(models_dir, exist_ok=True)
print(f"✅ Created directories: {models_dir}")

# Check disk space
# The raw `df -h` output is printed as-is; it is not parsed here.
import subprocess
result = subprocess.run(['df', '-h', '/content/drive/MyDrive'], capture_output=True, text=True)
print("💾 Disk space:")
print(result.stdout)

print("✅ Setup complete! Ready for training.")

📦 Installing required packages...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m115.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m97.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[2K 

ValueError: mount failed

In [None]:
# ============================================
# CELL 1: Batch Download All Sermon URLs
# ============================================

!pip install yt-dlp

import yt_dlp
import os
import re
from pathlib import Path

def extract_video_id(url):
    """Extract the video ID from a YouTube URL.

    Recognised forms: ``watch?v=`` (with ``v`` anywhere in the query string),
    ``youtu.be/<id>``, and the ``/embed/``, ``/shorts/``, ``/live/`` and
    ``/v/`` paths.

    Args:
        url: URL string to inspect. Non-string values yield ``None``.

    Returns:
        The video ID string, or ``None`` when no supported form matches.
    """
    if not isinstance(url, str):
        return None

    # Only letters, digits, '-' and '_' are captured. The caller embeds the ID
    # in an output file path, so separators and whitespace must never leak in.
    video_id = r'([A-Za-z0-9_-]+)'
    patterns = (
        r'(?:youtu\.be/|youtube(?:-nocookie)?\.com/(?:embed|shorts|live|v)/)' + video_id,
        # `v` must be a whole parameter name (not the tail of e.g. `xv=`);
        # the lazy optional group prefers the first `v=` in the query string.
        r'youtube\.com/watch\?(?:[^#\s]*?&)??v=' + video_id,
    )

    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None

def download_sermon_with_transcript(url, output_dir="/content/drive/MyDrive/sermon_data"):
    """Fetch one sermon's audio (as WAV) plus its English subtitles via yt-dlp.

    Args:
        url: YouTube URL of the sermon.
        output_dir: Directory that receives the ``sermon_<video_id>.*`` files.

    Returns:
        True when yt-dlp finishes without raising; False when no video ID can
        be extracted from ``url`` or when the download raises.
    """
    video_id = extract_video_id(url)
    if not video_id:
        print(f"❌ Could not extract video ID from: {url}")
        return False

    # Post-processing step that asks FFmpeg for a WAV audio track
    wav_extraction = {'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav'}

    ydl_opts = dict(
        format='bestaudio/best',                             # audio only
        outtmpl=f'{output_dir}/sermon_{video_id}.%(ext)s',
        writesubtitles=True,                                 # uploaded subtitles
        writeautomaticsub=True,                              # auto-generated subtitles
        subtitleslangs=['en'],
        subtitlesformat='vtt',
        postprocessors=[wav_extraction],
        quiet=True,                                          # keep yt-dlp output short
    )

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as downloader:
            print(f"🔄 Downloading: {video_id}")
            metadata = downloader.extract_info(url, download=True)

            title = metadata.get('title', 'unknown')
            print(f"   ✅ Downloaded: {title[:50]}...")
            return True
    except Exception as err:
        # Best-effort batch job: report the failure and let the caller continue
        print(f"   ❌ Error downloading {video_id}: {err}")
        return False

def process_sermon_urls(sermon_urls=None, output_dir="/content/drive/MyDrive/sermon_data"):
    """Download every sermon URL, then summarise what is in the output folder.

    Args:
        sermon_urls: Iterable of YouTube URLs. When omitted, the list embedded
            in this function is used (empty unless edited).
        output_dir: Folder the downloads are written to and that is listed in
            the summary. It is passed through to
            ``download_sermon_with_transcript`` so both always agree.

    Returns:
        None. Progress and the summary are printed.
    """
    if sermon_urls is None:
        # Sermon URLs from your file
        sermon_urls = [
        ]
    else:
        # Materialise once so len() works and generators are not consumed twice
        sermon_urls = list(sermon_urls)

    total = len(sermon_urls)

    print("📺 BATCH DOWNLOADING SERMONS")
    print("=" * 50)
    print(f"🎯 Total sermons to download: {total}")
    print(f"📁 Output directory: {output_dir}")
    print("=" * 50)

    successful_downloads = 0
    for i, url in enumerate(sermon_urls, 1):
        print(f"\n📥 [{i}/{total}] Processing...")
        if download_sermon_with_transcript(url, output_dir=output_dir):
            successful_downloads += 1
    failed_downloads = total - successful_downloads

    # Summary
    print("\n📊 DOWNLOAD SUMMARY")
    print("=" * 50)
    print(f"✅ Successful downloads: {successful_downloads}")
    print(f"❌ Failed downloads: {failed_downloads}")
    print(f"📁 Files saved to: {output_dir}")

    # Nothing to list when the folder was never created (e.g. every download failed)
    if not os.path.isdir(output_dir):
        return

    files = os.listdir(output_dir)
    audio_files = sorted(f for f in files if f.endswith('.wav'))
    vtt_files = sorted(f for f in files if f.endswith('.vtt'))

    print("\n📋 Downloaded files:")
    print(f"   🎵 Audio files: {len(audio_files)}")
    for f in audio_files:
        print(f"      - {f}")

    print(f"   📝 Transcript files: {len(vtt_files)}")
    for f in vtt_files:
        print(f"      - {f}")

    if audio_files and vtt_files:
        print("\n🎯 Ready for Whisper training!")
        print(f"📊 You now have {len(audio_files)} sermon audio files")
        print(f"📝 with {len(vtt_files)} timestamped transcripts")
    else:
        print("\n⚠️  Some downloads may have failed")
        print("Check the output above for errors")

# Run the batch download as soon as this cell executes.
# The `sermon_urls` list inside process_sermon_urls() is empty as written, so
# fill it in first; otherwise this only prints the summary of existing files.
process_sermon_urls()

In [None]:
# 🔧 SETUP WHISPER TRAINING FOLDER
# Run this cell in your TRAINING COLAB before looking for training data

import os
from google.colab import drive

def setup_whisper_training_folder():
    """
    Mount Google Drive and make sure MyDrive/Whisper_Training_Data exists.

    Prints a short report, including up to three of the most recent
    `training_batch_*` sub-folders already present.

    Returns True when the folder exists at the end, False when it could not be
    created or when any step (mounting included) raised.
    """

    print("🔧 SETTING UP WHISPER TRAINING FOLDER...")
    print("=" * 50)

    try:
        # Mount Google Drive
        print("🔗 Mounting Google Drive...")
        drive.mount('/content/drive')
        print("✅ Google Drive mounted successfully!")

        # Location of the main training data folder
        drive_path = "/content/drive/MyDrive"
        training_folder = f"{drive_path}/Whisper_Training_Data"

        print("📂 Checking for: Whisper_Training_Data")

        if os.path.exists(training_folder):
            print("✅ Found existing: Whisper_Training_Data")
        else:
            print("📁 Folder doesn't exist - creating it...")
            os.makedirs(training_folder, exist_ok=True)
            print("✅ Created: Whisper_Training_Data")

        # Verify the folder is really there before reporting success
        if not os.path.exists(training_folder):
            print("\n❌ Failed to create folder")
            print("💡 Try running this cell again")
            return False

        print("\n🎉 SUCCESS! Folder is ready for training data")
        print("📍 Location: /content/drive/MyDrive/Whisper_Training_Data")

        # Listing is informational only; a failure here must not fail the setup
        try:
            batch_folders = []
            for entry in os.listdir(training_folder):
                is_batch_dir = (os.path.isdir(os.path.join(training_folder, entry))
                                and entry.startswith('training_batch_'))
                if is_batch_dir:
                    batch_folders.append(entry)

            if not batch_folders:
                print("\n📊 Folder is empty (ready for first batch)")
            else:
                batch_count = len(batch_folders)
                print(f"\n📊 Found {batch_count} existing training batches:")
                # Reverse lexical order puts the latest timestamps first
                for batch in sorted(batch_folders, reverse=True)[:3]:
                    print(f"   • {batch}")
                if batch_count > 3:
                    print(f"   ... and {batch_count - 3} more")
        except Exception as listing_error:
            print(f"📊 Folder exists but couldn't list contents: {listing_error}")

        print("\n✅ Your training colab should now work!")
        return True

    except Exception as setup_error:
        print(f"❌ Error: {setup_error}")
        print("💡 Try running this cell again")
        return False

# 🚀 RUN THE SETUP
print("🚀 Running Whisper training folder setup...")
success = setup_whisper_training_folder()

# Both outcomes are framed by the same 60-column rule; only the message differs
print("\n" + "="*60)
if not success:
    print("❌ SETUP FAILED")
    print("   Try running this cell again")
    print("   Check your Google Drive permissions")
else:
    print("🎯 READY FOR TRAINING!")
    print("   Your training colab should now find the Whisper_Training_Data folder")
    print("   You can now run your training data processing")
print("="*60)

🚀 Running Whisper training folder setup...
🔧 SETTING UP WHISPER TRAINING FOLDER...
🔗 Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted successfully!
📂 Checking for: Whisper_Training_Data
✅ Found existing: Whisper_Training_Data

🎉 SUCCESS! Folder is ready for training data
📍 Location: /content/drive/MyDrive/Whisper_Training_Data

📊 Found 5 existing training batches:
   • training_batch_2025-07-26_14-22-20
   • training_batch_2025-07-26_08-32-16
   • training_batch_2025-07-26_08-30-35
   ... and 2 more

✅ Your training colab should now work!

🎯 READY FOR TRAINING!
   Your training colab should now find the Whisper_Training_Data folder
   You can now run your training data processing


In [None]:
# ============================================
# CELL 2: Data Preparation - YouTube OR Google Drive
# ============================================

import os
import re
import librosa
import numpy as np
from pathlib import Path
from datasets import Dataset
from google.colab import drive
import shutil

# ============================================
# QUICK CONFIG: Set to skip menu and use Google Drive directly
# ============================================
USE_GOOGLE_DRIVE_DIRECTLY = True  # Set to False if you want the menu

def setup_data_source(show_list_option=True):
    """
    Interactively choose the data source: Google Drive batch or YouTube download.

    The menu is shown again after an invalid answer, and after option 3 (which
    only previews the available Drive batches). Re-prompting is a loop rather
    than recursion, so repeated invalid input cannot hit the recursion limit.

    Args:
        show_list_option: Offer option 3 (list Drive batches). It is withdrawn
            automatically once it has been used.

    Returns: dataset ready for training (whatever the selected loader returns,
    which may be None on failure)
    """
    while True:
        print("📊 DATA SOURCE SELECTION")
        print("=" * 40)
        print("Choose your data source:")
        print("1. 📱 Google Drive batch (from your YouTube processor)")
        print("2. 🔄 YouTube download (original method)")

        if show_list_option:
            print("3. 📋 List available Google Drive batches (preview only)")
            choice = input("Enter choice (1/2/3): ").strip()
        else:
            choice = input("Enter choice (1/2): ").strip()

        if choice == "1":
            return load_from_google_drive()

        if choice == "2":
            return load_from_youtube_download()

        if choice == "3" and show_list_option:
            list_google_drive_batches()
            print("\n" + "="*40)
            print("💡 Now choose option 1 or 2 to load your data:")
            show_list_option = False  # Don't show list option again
            continue

        print("❌ Invalid choice")

def list_google_drive_batches():
    """List available Google Drive training batches.

    Mounts Drive, then prints every ``training_batch_*`` folder under
    ``MyDrive/Whisper_Training_Data`` (newest name first) with its VTT and WAV
    file counts. Returns None in every case; problems are reported by printing.
    """

    print("🔗 LISTING GOOGLE DRIVE BATCHES...")

    # Mount Google Drive
    try:
        drive.mount('/content/drive')
        print("✅ Google Drive mounted")
    except Exception as e:
        print(f"❌ Failed to mount Google Drive: {e}")
        return

    drive_base = "/content/drive/MyDrive/Whisper_Training_Data"

    if not os.path.exists(drive_base):
        print("❌ No Whisper_Training_Data folder found in Google Drive")
        print("💡 Make sure your YouTube processor has uploaded batches first")
        return

    # Batch folder names embed a timestamp, so reverse lexical order is newest first
    batches = sorted(
        (item for item in os.listdir(drive_base)
         if os.path.isdir(os.path.join(drive_base, item)) and item.startswith('training_batch_')),
        reverse=True,
    )

    if not batches:
        print("📂 No training batches found")
        return

    print("📁 AVAILABLE TRAINING BATCHES:")
    print("-" * 30)

    for i, batch in enumerate(batches, 1):
        batch_path = Path(drive_base) / batch

        try:
            vtt_files = len(list(batch_path.glob("*.vtt")))
            wav_files = len(list(batch_path.glob("*.wav")))
        except OSError:
            # Only filesystem errors are expected here; a bare `except` would
            # also have swallowed KeyboardInterrupt and SystemExit.
            print(f"{i}. {batch} (error reading)")
            continue

        timestamp = batch.replace('training_batch_', '').replace('_', ' ')

        print(f"{i}. {batch}")
        print(f"   📅 Created: {timestamp}")
        print(f"   📊 Files: {vtt_files} VTT, {wav_files} WAV")
        print()

def load_from_google_drive():
    """Pick a training batch on Google Drive, copy it locally and build a dataset.

    Mounts Drive, offers up to five of the newest ``training_batch_*`` folders,
    copies the chosen batch's .wav/.vtt/.txt files into
    ``/content/training_data`` (wiping that folder first) and hands the local
    copy to ``create_dataset_from_files``.

    Returns that function's result, or None when Drive cannot be mounted, the
    base folder is missing, or no batch exists.
    """

    print("🔗 LOADING FROM GOOGLE DRIVE...")

    # Mount Google Drive
    try:
        drive.mount('/content/drive')
        print("✅ Google Drive mounted")
    except Exception as mount_error:
        print(f"❌ Failed to mount Google Drive: {mount_error}")
        return None

    drive_base = "/content/drive/MyDrive/Whisper_Training_Data"
    if not os.path.exists(drive_base):
        print("❌ No Whisper_Training_Data folder found")
        return None

    # Batch folders, newest name first
    batches = sorted(
        (name for name in os.listdir(drive_base)
         if os.path.isdir(os.path.join(drive_base, name)) and name.startswith('training_batch_')),
        reverse=True,
    )
    if not batches:
        print("📂 No training batches found")
        return None

    # Offer at most the five newest batches
    selectable = min(5, len(batches))
    print("📁 Available batches:")
    for number, batch in enumerate(batches[:5], 1):
        timestamp = batch.replace('training_batch_', '').replace('_', ' ')
        print(f"  {number}. {batch} ({timestamp})")

    batch_choice = input(f"Enter batch number (1-{selectable}) or press Enter for newest: ").strip()

    # Anything that is not a valid number in range silently means "newest"
    if batch_choice.isdigit() and 1 <= int(batch_choice) <= selectable:
        selected_batch = batches[int(batch_choice) - 1]
    else:
        selected_batch = batches[0]
        print(f"🎯 Using newest batch: {selected_batch}")

    source_path = os.path.join(drive_base, selected_batch)
    local_path = "/content/training_data"

    print(f"📥 Loading batch: {selected_batch}")
    print(f"📂 Copying to: {local_path}")

    # Start from an empty local directory so stale files never mix in
    if os.path.exists(local_path):
        shutil.rmtree(local_path)
    os.makedirs(local_path)

    # Copy only the audio and transcript files
    wanted_suffixes = ('.wav', '.vtt', '.txt')
    files_copied = 0
    for entry in Path(source_path).iterdir():
        if not entry.is_file() or entry.suffix.lower() not in wanted_suffixes:
            continue
        shutil.copy2(entry, Path(local_path) / entry.name)
        files_copied += 1

    print(f"✅ Copied {files_copied} files")

    # The dataset builder reads the VTT files directly
    return create_dataset_from_files(local_path)

# VTT conversion functions removed - now handled directly in create_dataset_from_files

def clean_vtt_formatting(text):
    """Strip WebVTT markup from one line of cue text.

    Removes cue span tags (``<c>``, ``<i>``, ``<b>``, ``<u>``, ``<v ...>``,
    ``<ruby>``, ``<rt>``, ``<lang ...>``, with or without ``.class``
    annotations), inline ``<HH:MM:SS.mmm>`` / ``<MM:SS.mmm>`` timestamps and
    stray ``align:start position:N%`` cue settings. Character escapes such as
    ``&amp;`` are then decoded and whitespace runs are collapsed.

    Args:
        text: Raw cue text.

    Returns:
        The plain text, stripped of leading/trailing whitespace.
    """
    import html  # local import keeps this block self-contained

    # Opening/closing cue tags, optionally carrying ".class" or " annotation"
    text = re.sub(r'</?(?:c|i|b|u|v|ruby|rt|lang)(?:[.\s][^>]*)?>', '', text)
    # Inline word-level timestamps
    text = re.sub(r'<(?:\d{2,}:)?\d{2}:\d{2}\.\d{3}>', '', text)
    text = re.sub(r'align:start position:\d+%', '', text)
    # Decode escapes only after tag removal so an escaped "&lt;c&gt;" stays literal
    text = html.unescape(text)
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

def load_from_youtube_download():
    """Build the dataset from files saved by the YouTube download cell.

    Returns the result of ``create_dataset_from_files`` for
    ``MyDrive/sermon_data``, or None when that folder does not exist.
    """

    print("🔄 LOADING FROM YOUTUBE DOWNLOAD...")
    print("💡 Make sure you've run the YouTube download cell first!")

    # Folder the download cell writes into
    data_dir = "/content/drive/MyDrive/sermon_data"

    if os.path.exists(data_dir):
        return create_dataset_from_files(data_dir)

    print("❌ No YouTube download data found")
    print("💡 Run the YouTube download cell first")
    return None

def create_dataset_from_files(data_dir):
    """Create a HuggingFace dataset from the audio/transcript files in a folder.

    Every ``*.wav`` file is paired with a transcript:

    * a ``.vtt`` file (preferred) is split into one example per cue, keeping
      cues that are longer than 0.5 s and have non-empty text;
    * a ``.txt`` file is used as one whole-file example when there is no VTT,
      when VTT processing fails, or when the VTT yields no usable cue.

    Audio is loaded through ``librosa`` at 16 kHz.

    Args:
        data_dir: Folder containing the ``.wav`` / ``.vtt`` / ``.txt`` files.

    Returns:
        A ``datasets.Dataset`` with ``audio`` (``array`` + ``sampling_rate``)
        and ``text`` columns, or None when no WAV file or no valid pair exists.
    """
    target_sr = 16000
    report_every = 20

    print(f"📂 Creating dataset from: {data_dir}")

    # Find audio and transcript files
    data_path = Path(data_dir)
    wav_files = list(data_path.glob("*.wav"))
    vtt_files = list(data_path.glob("*.vtt"))
    txt_files = list(data_path.glob("*.txt"))

    print(f"📊 Found {len(wav_files)} WAV files, {len(vtt_files)} VTT files, {len(txt_files)} TXT files")

    if not wav_files:
        print("❌ No WAV files found!")
        return None

    def find_transcript(base_name, suffix, candidates):
        """Return the transcript for base_name: exact name first, then a loose match."""
        exact = data_path / f"{base_name}{suffix}"
        if exact.exists():
            return exact
        # Loose match covers names with extra parts around the stem
        # (presumably e.g. "<name>.en.vtt" - confirm against the downloader output)
        for candidate in candidates:
            if base_name in candidate.stem or candidate.stem in base_name:
                return candidate
        return None

    dataset_items = []
    next_report = report_every

    for wav_file in wav_files:
        base_name = wav_file.stem
        vtt_file = find_transcript(base_name, ".vtt", vtt_files)
        txt_file = find_transcript(base_name, ".txt", txt_files)

        # No transcript at all: skip without paying for the audio load
        if vtt_file is None and txt_file is None:
            continue

        # Load the audio once; both the VTT and the TXT path use it
        try:
            audio, sr = librosa.load(str(wav_file), sr=target_sr)
        except Exception as e:
            print(f"⚠️  Error processing {wav_file.name}: {e}")
            continue

        added = 0

        # Preferred: one example per VTT cue
        if vtt_file is not None:
            try:
                with open(vtt_file, 'r', encoding='utf-8') as f:
                    segments = parse_vtt_segments(f.read())
                print(f"   📝 Found {len(segments)} VTT segments in {wav_file.name}")

                for segment in segments:
                    text = segment['text'].strip()
                    start_sample = int(segment['start'] * sr)
                    # Clamp cues that run past the end of the audio
                    end_sample = min(int(segment['end'] * sr), len(audio))

                    if start_sample >= len(audio) or end_sample <= start_sample:
                        continue

                    audio_segment = audio[start_sample:end_sample]

                    # Only include segments with sufficient length and text
                    if len(audio_segment) > sr * 0.5 and text:  # At least 0.5 seconds
                        dataset_items.append({
                            "audio": {"array": audio_segment, "sampling_rate": target_sr},
                            "text": text
                        })
                        added += 1

            except Exception as e:
                print(f"⚠️  Error processing VTT {vtt_file.name}: {e}")

        # Fallback to TXT file if no VTT, VTT failed, or VTT produced nothing
        if added == 0 and txt_file is not None:
            try:
                with open(txt_file, 'r', encoding='utf-8') as f:
                    text = f.read().strip()

                if text and len(audio) > target_sr:  # At least 1 second
                    dataset_items.append({
                        "audio": {"array": audio, "sampling_rate": target_sr},
                        "text": text
                    })
                    added += 1

            except Exception as e:
                print(f"⚠️  Error processing {wav_file.name}: {e}")

        # Report once each time another `report_every` examples have accumulated.
        # One VTT file can add many examples, so an exact-multiple check would
        # skip most reports.
        if len(dataset_items) >= next_report:
            print(f"   Processed {len(dataset_items)} examples...")
            next_report = (len(dataset_items) // report_every + 1) * report_every

    if not dataset_items:
        print("❌ No valid audio-text pairs created!")
        return None

    # Create HuggingFace dataset
    dataset = Dataset.from_list(dataset_items)
    print(f"✅ Created dataset with {len(dataset)} examples from {len(wav_files)} audio files")

    return dataset

def parse_vtt_segments(vtt_content):
    """Parse VTT content into individual segments with timestamps.

    Every cue timing line (``start --> end``) opens a segment; the lines that
    follow, up to the next blank line or timing line, form its text. Both
    ``HH:MM:SS.mmm`` and the hour-less ``MM:SS.mmm`` timestamp forms are
    accepted. Header lines, cue identifiers and any other line without ``-->``
    are ignored, and cues whose cleaned text is empty are dropped.

    Args:
        vtt_content: Full text of a ``.vtt`` file.

    Returns:
        List of dicts with ``start`` and ``end`` (seconds, float) and ``text``.
    """
    timestamp = r'((?:\d{2,}:)?\d{2}:\d{2}[.,]\d{3})'
    cue_timing = re.compile(timestamp + r'\s*-->\s*' + timestamp)

    def with_hours(stamp):
        # Normalise "MM:SS.mmm" to "00:MM:SS.mmm" so the converter always
        # receives three colon-separated fields.
        return stamp if stamp.count(':') >= 2 else '00:' + stamp

    segments = []
    lines = vtt_content.split('\n')

    i = 0
    while i < len(lines):
        line = lines[i].strip()
        i += 1

        # Only cue timing lines start a segment; everything else is skipped
        if '-->' not in line:
            continue
        time_match = cue_timing.match(line)
        if not time_match:
            continue

        try:
            start_time = convert_vtt_time_to_seconds(with_hours(time_match.group(1)))
            end_time = convert_vtt_time_to_seconds(with_hours(time_match.group(2)))

            # Collect text lines for this segment
            text_lines = []
            while i < len(lines) and lines[i].strip() and '-->' not in lines[i]:
                cleaned_text = clean_vtt_formatting(lines[i].strip())
                if cleaned_text:
                    text_lines.append(cleaned_text)
                i += 1
        except Exception:
            # Best effort: report the bad cue and carry on with the next line
            print(f"⚠️  Error parsing timestamp: {line}")
            continue

        full_text = ' '.join(text_lines).strip()
        if full_text:
            segments.append({
                'start': start_time,
                'end': end_time,
                'text': full_text
            })

    return segments

def convert_vtt_time_to_seconds(time_str):
    """Convert a VTT timestamp to seconds.

    Accepts ``HH:MM:SS.mmm`` (e.g. "00:01:39.301") and the hour-less
    ``MM:SS.mmm`` form. A comma is tolerated as the decimal separator.

    Args:
        time_str: Timestamp string.

    Returns:
        The time in seconds as a float.

    Raises:
        ValueError: If the string does not have two or three colon-separated
            fields, or a field is not numeric.
    """
    parts = time_str.strip().replace(',', '.').split(':')

    if len(parts) == 3:
        hours, minutes = int(parts[0]), int(parts[1])
    elif len(parts) == 2:
        hours, minutes = 0, int(parts[0])
    else:
        raise ValueError(f"Unsupported VTT timestamp: {time_str!r}")

    seconds = float(parts[-1])
    return hours * 3600 + minutes * 60 + seconds

# ============================================
# MAIN DATA PREPARATION
# ============================================

# Builds the module-level `dataset` variable that the training cell looks up
print("🚀 DATA PREPARATION WITH GOOGLE DRIVE SUPPORT")
print("=" * 55)

# Quick mode goes straight to Google Drive; otherwise the source menu is shown
if not USE_GOOGLE_DRIVE_DIRECTLY:
    dataset = setup_data_source()
else:
    print("🎯 QUICK MODE: Using Google Drive directly")
    print("💡 To use menu instead, set USE_GOOGLE_DRIVE_DIRECTLY = False")
    dataset = load_from_google_drive()

if dataset is None:
    print("❌ Failed to create dataset")
    print("💡 Try running this cell again and choose a different option")
else:
    # Short preview of the first example as a sanity check
    print("\n✅ DATASET READY FOR TRAINING!")
    print(f"📊 Total examples: {len(dataset)}")
    print(f"📝 First example text: {dataset[0]['text'][:100]}...")
    print(f"🎵 Audio sample rate: {dataset[0]['audio']['sampling_rate']}")
    print(f"⏱️  Audio length: {len(dataset[0]['audio']['array']) / dataset[0]['audio']['sampling_rate']:.1f}s")
    print("\n🎯 Ready to run your training cell!")

print(
    "\n" + "="*55,
    "💡 WHAT'S NEXT:",
    "   If dataset loaded successfully, run your training cell!",
    "   The training cell will use the 'dataset' variable created here.",
    sep="\n",
)

🚀 DATA PREPARATION WITH GOOGLE DRIVE SUPPORT
🎯 QUICK MODE: Using Google Drive directly
💡 To use menu instead, set USE_GOOGLE_DRIVE_DIRECTLY = False
🔗 LOADING FROM GOOGLE DRIVE...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted
📁 Available batches:
  1. training_batch_2025-07-26_14-22-20 (2025-07-26 14-22-20)
  2. training_batch_2025-07-26_08-32-16 (2025-07-26 08-32-16)
  3. training_batch_2025-07-26_08-30-35 (2025-07-26 08-30-35)
  4. training_batch_2025-07-26_08-00-57 (2025-07-26 08-00-57)
  5. training_batch_2025-07-26_08-00-10 (2025-07-26 08-00-10)
Enter batch number (1-5) or press Enter for newest: 
🎯 Using newest batch: training_batch_2025-07-26_14-22-20
📥 Loading batch: training_batch_2025-07-26_14-22-20
📂 Copying to: /content/training_data
✅ Copied 53 files
📂 Creating dataset from: /content/training_data
📊 Found 28 WAV files, 24 VTT files, 1 TXT files
   📝 Found 110 VTT segmen

In [None]:
# ============================================
# CELL 3: Cycling Training (All examples, 50 at a time - FIXED VERSION)
# ============================================

import torch
from datasets import Dataset
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import os
import gc
import random
import shutil

# Disk space management
def check_disk_space(path='/content/drive/MyDrive'):
    """Check available disk space and clean if needed.

    Free space is read with ``shutil.disk_usage`` instead of parsing
    ``df -h`` text. The old parsing only understood a ``G`` suffix, so a
    nearly full disk reported in ``M`` was treated as fine.

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        False when less than 2 GB is free (after triggering
        ``cleanup_disk_space``); True otherwise, including when the space
        cannot be determined (e.g. ``path`` does not exist).
    """
    import shutil  # local import keeps this block self-contained

    try:
        usage = shutil.disk_usage(path)
    except OSError as e:
        # Unknown is not the same as full: warn and let the caller proceed
        print(f"⚠️  Could not check disk space for {path}: {e}")
        return True

    bytes_per_gb = 1024 ** 3
    used_gb = usage.used / bytes_per_gb
    available_gb = usage.free / bytes_per_gb
    print(f"💾 Disk usage: {used_gb:.1f}G used, {available_gb:.1f}G available")

    # Only treat it as a failure below 2 GB; between 2 and 5 GB just warn
    if available_gb < 2.0:
        print("⚠️  Low disk space! Cleaning up...")
        cleanup_disk_space()
        return False

    if available_gb < 5.0:
        print(f"⚠️  Moderate disk space: {available_gb:.1f} GB available")
    else:
        print(f"✅ Good disk space: {available_gb:.1f} GB available")
    return True

def cleanup_disk_space():
    """Clean up disk space.

    Empties the temp/cache directories, purges the pip cache (and the conda
    cache when conda is installed), then runs garbage collection and releases
    cached CUDA memory.

    The directories are emptied entry by entry rather than deleted and
    recreated, so the directory itself (and its permissions) survives, and one
    undeletable entry no longer aborts the rest of that directory.

    NOTE(review): '/root/.cache' is cleared wholesale; other tools may keep
    downloads there that would then have to be fetched again - confirm this is
    intended.
    """
    import subprocess
    import sys

    print("🧹 Cleaning up disk space...")

    # Clear temporary files
    temp_dirs = ['/tmp', '/content/.cache', '/root/.cache']
    for temp_dir in temp_dirs:
        if not os.path.isdir(temp_dir):
            continue

        try:
            entries = list(os.scandir(temp_dir))
        except OSError as e:
            print(f"⚠️  Could not clear {temp_dir}: {e}")
            continue

        failures = 0
        for entry in entries:
            try:
                # Never follow symlinks: remove the link, not its target
                if entry.is_dir(follow_symlinks=False):
                    shutil.rmtree(entry.path)
                else:
                    os.remove(entry.path)
            except OSError:
                failures += 1

        if failures:
            print(f"⚠️  Could not clear {temp_dir}: {failures} entries could not be removed")
        else:
            print(f"✅ Cleared {temp_dir}")

    def run_best_effort(command):
        """Run a cleanup command, ignoring failures and discarding stderr."""
        try:
            subprocess.run(command, check=False, stderr=subprocess.DEVNULL)
        except OSError:
            pass  # executable missing or not runnable: nothing to clean

    # Clear pip cache (argument list, so no shell is involved)
    run_best_effort([sys.executable, '-m', 'pip', 'cache', 'purge'])

    # Clear conda cache, only when conda is actually installed
    if shutil.which('conda'):
        run_best_effort(['conda', 'clean', '-a', '-y'])

    # Force garbage collection
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("✅ Disk cleanup complete!")

def check_memory_usage():
    """Report system RAM usage via psutil and flag low availability.

    Returns False only when less than 5 GB of RAM is available; True
    otherwise, including when psutil is not installed.
    """
    bytes_per_gb = 1024**3
    try:
        import psutil

        stats = psutil.virtual_memory()
        used_gb = stats.used / bytes_per_gb
        total_gb = stats.total / bytes_per_gb
        available_gb = stats.available / bytes_per_gb

        print(f"💾 RAM Usage: {used_gb:.1f}GB used, {available_gb:.1f}GB available, {total_gb:.1f}GB total")

        # Three bands: 10 GB and up is fine, 5-10 GB is watched, below 5 GB is a warning
        if available_gb >= 10.0:
            print("✅ Good RAM availability")
        elif available_gb >= 5.0:
            print("⚠️  Moderate RAM available. Monitoring closely...")
        else:
            print("⚠️  Low RAM available! Consider reducing chunk size.")
            return False
        return True
    except ImportError:
        print("⚠️  psutil not available, skipping memory check")
        return True

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech seq2seq examples into one padded batch.

    Audio features and label token ids are padded separately (through the
    processor's feature extractor and tokenizer respectively), and label
    padding is replaced by -100.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`
    processor: Any
    # Token id compared against the first label column to detect a leading start token
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and padding rules, so they
        # are separated before padding
        audio_inputs = []
        token_inputs = []
        for feature in features:
            audio_inputs.append({"input_features": feature["input_features"]})
            token_inputs.append({"input_ids": feature["labels"]})

        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = self.processor.tokenizer.pad(token_inputs, return_tensors="pt")

        # Padding positions (attention_mask != 1) become -100
        token_ids = padded_labels["input_ids"]
        labels = token_ids.masked_fill(padded_labels.attention_mask.ne(1), -100)

        # When every sequence already starts with the decoder start token,
        # drop that first column
        starts_with_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if starts_with_bos else labels

        return batch

def _prepare_cycling_example(example, processor, max_label_tokens=448):
    """Convert one raw ``{"audio", "text"}`` example into training tensors.

    Args:
        example: Mapping with an ``"audio"`` entry (either a bare waveform or a
            dict holding ``"array"`` and optionally ``"sampling_rate"``) and a
            ``"text"`` transcript.
        processor: Whisper processor used for feature extraction and tokenizing.
        max_label_tokens: Upper bound on label length; longer transcripts are
            truncated by the tokenizer.

    Returns:
        dict: ``{"input_features": tensor, "labels": tensor}`` with the batch
        dimension squeezed away.
    """
    audio = example["audio"]
    # The audio column may arrive as a dict wrapping the waveform; unwrap it
    # instead of handing the dict to the feature extractor.
    if isinstance(audio, dict) and "array" in audio:
        waveform = audio["array"]
        sampling_rate = audio.get("sampling_rate", 16000)
    else:
        waveform = audio
        sampling_rate = 16000

    inputs = processor(waveform, sampling_rate=sampling_rate, return_tensors="pt")
    # Cap the label length so one long transcript cannot exceed the decoder's
    # maximum target length during training.
    labels = processor.tokenizer(
        example["text"],
        return_tensors="pt",
        truncation=True,
        max_length=max_label_tokens,
    ).input_ids.squeeze(0)

    return {
        "input_features": inputs.input_features.squeeze(0),  # drop batch dimension
        "labels": labels,
    }


def create_cycling_dataset(chunk_size=25, max_epochs=1):
    """Preprocess the global ``dataset`` once, in shuffled chunks.

    The examples are visited in random order, ``chunk_size`` at a time, and
    each one is converted to ``input_features`` / ``labels``. Examples or whole
    chunks that fail are reported and skipped. All preprocessed examples are
    accumulated in memory and finally wrapped in a ``Dataset``.

    Args:
        chunk_size: Number of examples selected from ``dataset`` per chunk.
            Must be at least 1.
        max_epochs: Unused; kept so existing callers keep working. The result
            always contains a single pass over the data.

    Returns:
        Dataset | None: The preprocessed dataset, or ``None`` when disk space
        is insufficient, the global ``dataset`` is missing or unreadable, no
        example could be preprocessed, or the final dataset cannot be built.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Check disk space first
    if not check_disk_space():
        print("❌ Insufficient disk space!")
        return None

    # The raw dataset is expected as a module-level global set by an earlier cell
    if 'dataset' not in globals() or dataset is None:
        print("❌ No dataset available!")
        return None

    print("📥 Loading processor for preprocessing...")
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    try:
        total_examples = len(dataset)
        print(f"📊 Total examples: {total_examples}")
    except Exception as e:
        print(f"❌ Error accessing dataset: {e}")
        return None

    print(f"🔄 Creating ULTRA memory-efficient streaming dataset...")
    print(f"💾 Using micro-chunk size: {chunk_size}")
    print(f"📦 Will process in very small batches to avoid RAM overflow")

    indices = list(range(total_examples))
    random.shuffle(indices)  # Randomize the order examples are processed in
    total_chunks = (total_examples + chunk_size - 1) // chunk_size

    all_examples = []
    chunk_starts = range(0, total_examples, chunk_size)
    for chunk_number, start in enumerate(chunk_starts, start=1):
        try:
            chunk_data = dataset.select(indices[start:start + chunk_size])
            print(f"   📦 Processing micro-chunk {chunk_number}/{total_chunks} ({len(chunk_data)} examples)")

            for j, example in enumerate(chunk_data):
                try:
                    all_examples.append(_prepare_cycling_example(example, processor))
                except Exception as e:
                    # Best effort: one bad example must not abort the whole run
                    print(f"⚠️  Error preprocessing example {j}: {e}")

            # Release the chunk before selecting the next one
            del chunk_data
            gc.collect()
            torch.cuda.empty_cache()

            # Check disk space and memory every 20 chunks (starting with the first)
            if (chunk_number - 1) % 20 == 0:
                check_disk_space()
                check_memory_usage()

        except Exception as chunk_error:
            print(f"⚠️  Error processing chunk {chunk_number}: {chunk_error}")
            print("🔄 Skipping this chunk and continuing...")
            continue

    if not all_examples:
        # Nothing usable was produced; let the caller abort instead of
        # starting a training run on an empty dataset.
        print("❌ No examples could be preprocessed!")
        return None

    try:
        print(f"📦 Creating final dataset from {len(all_examples)} examples...")
        cycling_dataset = Dataset.from_list(all_examples)
        print(f"✅ Created ultra memory-efficient dataset with {len(cycling_dataset)} examples (1 epoch)")
        print(f"📊 Original examples: {total_examples}, Processed: {len(cycling_dataset)}")

        # Drop the intermediate list now that the dataset holds the data
        del all_examples
        gc.collect()
        torch.cuda.empty_cache()

        return cycling_dataset
    except Exception as e:
        print(f"❌ Failed to create cycling dataset: {e}")
        return None

def train_whisper_cycling():
    """Fine-tune ``openai/whisper-small`` on the preprocessed global dataset.

    Builds the training set with ``create_cycling_dataset``, loads the model
    and processor, trains for one epoch with ``Seq2SeqTrainer`` and saves the
    result to Google Drive.

    Returns:
        str | None: The output directory the model was saved to, or ``None``
        when the dataset could not be created.

    Raises:
        Exception: Whatever ``trainer.train()`` raises is re-raised after an
        emergency memory cleanup.
    """
    chunk_size = 100  # preprocessing chunk size passed to create_cycling_dataset
    batch_size = 8
    accumulation_steps = 2
    learning_rate = 2e-5

    print(f"🚀 Starting Cycling Training (All examples, {chunk_size} at a time)...")

    # Check GPU availability
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"🖥️  Using device: {device}")
    if use_cuda:
        print(f"   GPU: {torch.cuda.get_device_name()}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f"   CUDA Version: {torch.version.cuda}")

        # Smoke-test the GPU with a small matmul before doing any real work
        print("🧪 Testing GPU functionality...")
        test_tensor = torch.randn(1000, 1000).cuda()
        result = torch.matmul(test_tensor, test_tensor)
        print(f"✅ GPU test successful! Result shape: {result.shape}")
        del test_tensor, result
        torch.cuda.empty_cache()
    else:
        print("❌ No GPU detected! Training will be very slow on CPU")
        print("💡 Make sure you're using a GPU runtime in Colab")

    # Create cycling dataset
    print("📊 Creating A100-optimized cycling dataset...")
    cycling_dataset = create_cycling_dataset(chunk_size=chunk_size, max_epochs=1)

    if cycling_dataset is None:
        print("❌ Failed to create cycling dataset!")
        print("💡 Try re-running the data preparation cell first")
        return

    # Load model and processor
    model_name = "openai/whisper-small"
    print(f"📥 Loading model: {model_name}")
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    if use_cuda:
        print(f"🖥️  Moving model to GPU: {torch.cuda.get_device_name()}")
        model = model.to(device)
        print(f"✅ Model moved to GPU")
        print(f"🔍 GPU memory after model load: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    else:
        print("⚠️  No GPU available, using CPU")

    # Preprocessing already happened while the dataset was created.

    # Create output directory in Google Drive
    model_output_dir = "/content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50"
    os.makedirs(model_output_dir, exist_ok=True)

    training_args = Seq2SeqTrainingArguments(
        output_dir=model_output_dir,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=accumulation_steps,  # effective batch size 16
        num_train_epochs=1,
        logging_steps=5,
        save_steps=50,
        # NOTE(review): eval_steps / predict_with_generate appear to have no
        # effect here because no eval dataset is passed to the trainer.
        eval_steps=50,
        warmup_steps=100,
        save_total_limit=5,
        save_strategy="steps",
        predict_with_generate=True,
        # Mixed precision and pinned memory only make sense with a GPU;
        # requesting fp16 without one makes the arguments object reject the
        # configuration, which would defeat the CPU fallback announced above.
        fp16=use_cuda,
        dataloader_pin_memory=use_cuda,
        remove_unused_columns=False,
        report_to="none",
        dataloader_num_workers=4,
        gradient_checkpointing=False,
        max_grad_norm=1.0,
        learning_rate=learning_rate,
        group_by_length=False,
        dataloader_drop_last=False,
        load_best_model_at_end=False,
        dataloader_prefetch_factor=4,
    )

    # Data collator for Whisper
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=cycling_dataset,
        data_collator=data_collator,
        tokenizer=processor.tokenizer,
    )

    print("🎯 Starting A100 AGGRESSIVE cycling training...")
    print(f"📊 Will cycle through examples in chunks of {chunk_size}")
    print(f"📊 Dataset size: {len(cycling_dataset)} examples (1 epoch)")
    if use_cuda:
        print(f"🖥️  GPU: {torch.cuda.get_device_name()}")
        print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f"🔍 Initial GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
        print(f"🔍 Initial GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB reserved")
    else:
        print("🖥️  GPU: CPU")
        print("No GPU")
    print(f"🚀 Batch size: {batch_size}, Effective batch size: {batch_size * accumulation_steps}")
    print(f"⚡ Learning rate: {learning_rate} (aggressive)")

    try:
        if use_cuda:
            print(f"🔍 GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")

        trainer.train()

        if use_cuda:
            print(f"🔍 GPU memory after training: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")

    except Exception as e:
        print(f"⚠️  Training error: {e}")
        print("🧹 Emergency memory cleanup...")
        gc.collect()
        torch.cuda.empty_cache()
        if use_cuda:
            torch.cuda.synchronize()
        # Bare raise keeps the original traceback intact
        raise

    print("🧹 Training complete - final memory cleanup...")
    gc.collect()
    torch.cuda.empty_cache()
    if use_cuda:
        torch.cuda.synchronize()

    # Save model to Google Drive (trainer writes to its configured output_dir)
    print("💾 Saving model to Google Drive...")
    trainer.save_model()
    print(f"✅ Model saved to: {model_output_dir}")

    return model_output_dir

# Run the cycling training directly.
# train_whisper_cycling() configures save_steps=50 with save_total_limit=5, so a
# checkpoint is written every 50 training steps and an interrupted run loses at
# most the steps since the last checkpoint.
# model_path is the Drive output directory, or None if the dataset could not be built.
model_path = train_whisper_cycling()
""

In [None]:
# ============================================
# CELL 3: Disk-Space Optimized Training
# ============================================

import torch
from datasets import Dataset
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import os
import gc
import shutil

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Fault-tolerant batch collator for speech seq2seq training.

    Entries that are ``None`` or lack ``"input_features"`` / ``"labels"`` are
    dropped before padding. If collation fails for any reason the error is
    printed and a placeholder batch is returned instead of raising.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
    processor: Any
    # Token id stripped from the front of the labels when every row starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of features into a batch; fall back to a dummy batch on error."""
        try:
            usable = [
                item
                for item in features
                if item is not None and "input_features" in item and "labels" in item
            ]
            if len(usable) == 0:
                raise ValueError("No valid features in batch")

            audio_items = [{"input_features": item["input_features"]} for item in usable]
            label_items = [{"input_ids": item["labels"]} for item in usable]

            # Audio and labels are padded separately by their own components
            batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")
            padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt")

            # Padding positions (attention mask != 1) become -100 so the loss skips them
            is_padding = padded_labels.attention_mask.ne(1)
            labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

            # Strip the leading start token only when the labels are non-empty
            # and every row begins with it
            has_columns = labels.size(1) > 0
            if has_columns and (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
                labels = labels[:, 1:]

            batch["labels"] = labels
            return batch

        except Exception as e:
            print(f"❌ Error in data collator: {e}")
            # Placeholder batch so the training loop does not crash.
            # NOTE(review): every label here is -100, so this batch carries no
            # usable targets — confirm that is acceptable for the loss.
            return {
                "input_features": torch.zeros((1, 80, 3000)),  # Whisper input shape
                "labels": torch.full((1, 10), -100, dtype=torch.long),
            }

def check_disk_space(path="/content/drive/MyDrive", min_free_gb=3.0, warn_free_gb=5.0):
    """Report free disk space for ``path`` and flag when it runs low.

    Free space is read with ``shutil.disk_usage`` rather than by parsing
    human-readable ``df -h`` output, so values reported in units other than
    gigabytes (for example a few hundred megabytes) are classified correctly.

    Args:
        path: Filesystem location to inspect.
        min_free_gb: Below this many GiB free the check fails.
        warn_free_gb: Below this many GiB free a warning is printed, but the
            check still passes.

    Returns:
        bool: ``False`` when free space is below ``min_free_gb``; ``True``
        otherwise, including when the free space cannot be determined (the
        check is best-effort and must not block training on its own failure).
    """
    import shutil

    try:
        free_bytes = shutil.disk_usage(path).free
    except OSError as e:
        print(f"⚠️  Could not check disk space: {e}")
        return True

    available_gb = free_bytes / (1024**3)
    print(f"💾 Available disk space: {available_gb:.1f}G")

    if available_gb < min_free_gb:
        print("⚠️  LOW DISK SPACE! Consider stopping training.")
        return False
    if available_gb < warn_free_gb:
        print("⚠️  Moderate disk space remaining")
    return True

def cleanup_old_checkpoints(output_dir, keep_latest=1):
    """Delete all but the newest ``checkpoint-<step>`` directories.

    Only directories named ``checkpoint-`` followed by a decimal step number
    are considered; anything else in ``output_dir`` is left untouched.
    "Newest" means the highest step number.

    Args:
        output_dir: Directory that contains the checkpoint directories.
        keep_latest: How many of the newest checkpoints to keep. ``0`` removes
            every checkpoint; negative values are treated as ``0``.

    Returns:
        None. Filesystem errors are reported with a printed warning rather
        than raised, since cleanup is best-effort.
    """
    prefix = 'checkpoint-'
    keep_latest = max(0, keep_latest)
    try:
        numbered = []
        for name in os.listdir(output_dir):
            step = name[len(prefix):]
            # Skip stray entries such as "checkpoint-tmp" instead of letting a
            # failed int() abort the whole cleanup.
            if not (name.startswith(prefix) and step.isdecimal()):
                continue
            if os.path.isdir(os.path.join(output_dir, name)):
                numbered.append((int(step), name))

        numbered.sort()  # numeric order, so checkpoint-100 is newer than checkpoint-20

        # Explicit end index: a "[:-keep_latest]" slice would be empty for
        # keep_latest == 0 and silently remove nothing.
        stale_count = max(0, len(numbered) - keep_latest)
        for _, name in numbered[:stale_count]:
            shutil.rmtree(os.path.join(output_dir, name))
            print(f"🗑️  Removed old checkpoint: {name}")
    except OSError as e:
        print(f"⚠️  Error cleaning checkpoints: {e}")

def train_with_minimal_disk_usage():
    """Fine-tune ``openai/whisper-small`` while keeping disk usage low.

    Reads the raw examples from the module-level ``dataset`` global, drops
    unusable examples, preprocesses the rest, and trains for three epochs with
    a single rolling checkpoint. After every checkpoint save a callback prunes
    old checkpoints and re-checks free disk space.

    Returns:
        str | None: The output directory on success. ``None`` when the global
        ``dataset`` is missing, disk space is insufficient, no valid example
        remains after filtering, or training raises.
    """
    max_audio_seconds = 30    # audio beyond this is trimmed
    min_audio_seconds = 0.5   # shorter clips are dropped
    max_text_chars = 1000     # transcript is cut to this many characters
    max_label_tokens = 448    # label length cap ("Whisper's max length")

    print("🚀 DISK-SPACE OPTIMIZED TRAINING")
    print("=" * 45)

    # The raw dataset is expected as a module-level global set by an earlier cell
    if 'dataset' not in globals() or dataset is None:
        print("❌ No dataset found!")
        print("💡 Run the data prep cell first")
        return None

    print(f"📊 Dataset ready: {len(dataset)} examples")

    if not check_disk_space():
        print("❌ Insufficient disk space to start training")
        return None

    # GPU setup
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"🖥️  Device: {device}")
    if use_cuda:
        print(f"   GPU: {torch.cuda.get_device_name()}")

    # Load model and processor
    model_name = "openai/whisper-small"
    print(f"📥 Loading {model_name}...")
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    if use_cuda:
        model = model.to(device)
        print("✅ Model moved to GPU")

    def unpack_audio(audio):
        """Return ``(waveform, sampling_rate)`` for dict-style or bare audio."""
        if isinstance(audio, dict) and "array" in audio:
            return audio["array"], audio.get("sampling_rate", 16000)
        return audio, 16000

    def is_trainable(example):
        """Tell whether an example has usable audio and text."""
        try:
            audio_data, sample_rate = unpack_audio(example["audio"])
            text = example["text"]

            if len(audio_data) == 0:
                print(f"⚠️  Skipping empty audio")
                return False
            if len(audio_data) < sample_rate * min_audio_seconds:
                print(f"⚠️  Skipping too short audio ({len(audio_data)/sample_rate:.1f}s)")
                return False
            if not text or len(text.strip()) == 0:
                print(f"⚠️  Skipping empty text")
                return False
            return True
        except Exception as e:
            print(f"⚠️  Error preprocessing example: {e}")
            return False

    def preprocess_function(example):
        """Convert one validated example into ``input_features`` / ``labels``."""
        audio_data, sample_rate = unpack_audio(example["audio"])
        audio_data = audio_data[:max_audio_seconds * sample_rate]  # trim if too long

        # Use the feature extractor's default padding. Passing
        # max_length=3000 here would be interpreted in audio samples, not mel
        # frames, and cut every clip to a fraction of a second.
        inputs = processor(
            audio_data,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        input_features = inputs.input_features.squeeze(0)

        text = example["text"].strip()[:max_text_chars]
        labels = processor.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_label_tokens,
        ).input_ids.squeeze(0)

        return {
            "input_features": input_features,
            "labels": labels
        }

    print("🔄 Preprocessing dataset...")
    # Rows must be dropped with filter(): a map() function that returns None
    # does not remove the row, so invalid examples are filtered out first and
    # only then converted. Errors during conversion of a validated example
    # are allowed to surface.
    valid_dataset = dataset.filter(is_trainable, desc="Validating")

    if len(valid_dataset) == 0:
        print("❌ No valid examples after preprocessing!")
        return None

    processed_dataset = valid_dataset.map(
        preprocess_function,
        remove_columns=valid_dataset.column_names,
        desc="Preprocessing"
    )

    print(f"✅ Preprocessed {len(processed_dataset)} valid examples")

    # Create output directory
    output_dir = "/content/drive/MyDrive/models/whisper-sermon-minimal"
    os.makedirs(output_dir, exist_ok=True)

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        num_train_epochs=3,
        logging_steps=1,

        # Disk space: frequent saves, but only one model-only checkpoint kept
        save_steps=10,
        save_total_limit=1,
        save_strategy="steps",
        save_only_model=True,       # skip optimizer state in checkpoints

        predict_with_generate=True,
        fp16=use_cuda,              # mixed precision needs a GPU
        remove_unused_columns=False,
        report_to="none",
        learning_rate=2e-5,

        # Memory optimizations
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
    )

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Imported here because the cell's top-level transformers import does not
    # include it.
    from transformers import TrainerCallback

    class DiskMonitorCallback(TrainerCallback):
        """Prune checkpoints and check disk space after every save.

        ``on_save`` is a callback event; defining a method of that name on a
        Trainer subclass is never invoked, so the hook lives in a callback.
        """

        def on_save(self, args, state, control, **kwargs):
            cleanup_old_checkpoints(args.output_dir, keep_latest=1)

            if not check_disk_space():
                print("⚠️  LOW DISK SPACE - Consider stopping training manually")

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,
        data_collator=data_collator,
        tokenizer=processor.tokenizer,
        callbacks=[DiskMonitorCallback()],
    )

    print(f"🎯 Starting DISK-OPTIMIZED training...")
    print(f"📊 Examples: {len(processed_dataset)}")
    print(f"🚀 Epochs: 3, Batch size: 4")
    print(f"💾 SAVES: Every 10 steps, keeps only 1 checkpoint")
    print(f"⚡ Learning rate: 2e-5")

    try:
        trainer.train()

        # Final save, then drop intermediate checkpoints and keep the final model
        trainer.save_model()
        cleanup_old_checkpoints(output_dir, keep_latest=0)

        print(f"✅ Training complete! Model saved to: {output_dir}")
        print(f"🗑️  All intermediate checkpoints cleaned up")

        return output_dir
    except Exception as e:
        print(f"❌ Training error: {e}")
        # Emergency cleanup: keep the newest checkpoint so progress is not lost
        cleanup_old_checkpoints(output_dir, keep_latest=1)
        return None

# Start disk-optimized training.
# train_with_minimal_disk_usage() returns the output directory on success and
# None when the dataset is missing, disk space is low, preprocessing leaves no
# examples, or training raises; model_path holds that result.
model_path = train_with_minimal_disk_usage()

🚀 DISK-SPACE OPTIMIZED TRAINING
📊 Dataset ready: 3264 examples
💾 Available disk space: 8.2G
🖥️  Device: cuda
   GPU: NVIDIA A100-SXM4-40GB
📥 Loading openai/whisper-small...
✅ Model moved to GPU
🔄 Preprocessing dataset...


Preprocessing:   0%|          | 0/3264 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
# ============================================
# COLAB CELL: WHISPER TRAINING WITH GOOGLE DRIVE SUPPORT
# ============================================

import torch
from datasets import Dataset
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import os
import gc
import random
import shutil
from pathlib import Path
import re
import librosa

# Google Drive support
from google.colab import drive

# ============================================
# GOOGLE DRIVE DATA LOADER
# ============================================

def load_training_data_from_drive(batch_name=None, list_only=False):
    """
    Load training data from Google Drive batches (from your YouTube processor)

    Mounts Google Drive, looks for ``training_batch_*`` directories under
    ``MyDrive/Whisper_Training_Data``, copies the selected batch's ``.wav``,
    ``.vtt`` and ``.txt`` files to ``/content/training_data`` (replacing any
    previous contents) and converts the VTT files to TXT.

    Args:
        batch_name: Specific batch name or None for newest. If the name is
            given but does not exist, a warning is printed and the newest
            batch is used instead.
        list_only: If True, just list available batches

    Returns:
        Path to local data directory, or None when Drive cannot be mounted,
        no batches exist, or ``list_only`` is True.
    """

    print("🔗 GOOGLE DRIVE DATA LOADER")
    print("=" * 40)

    # Mount Google Drive
    try:
        drive.mount('/content/drive')
        print("✅ Google Drive mounted")
    except Exception as e:
        print(f"❌ Failed to mount Google Drive: {e}")
        return None

    drive_base = "/content/drive/MyDrive/Whisper_Training_Data"

    if not os.path.exists(drive_base):
        print("❌ No Whisper_Training_Data folder found in Google Drive")
        print("💡 Make sure your YouTube processor has uploaded batches first")
        return None

    # Batch names are sorted in descending order and the first one is treated
    # as the newest.
    batches = sorted(
        (
            item
            for item in os.listdir(drive_base)
            if item.startswith('training_batch_')
            and os.path.isdir(os.path.join(drive_base, item))
        ),
        reverse=True,
    )

    if not batches:
        print("📂 No training batches found")
        return None

    # If listing only, show available batches
    if list_only:
        print("📁 AVAILABLE TRAINING BATCHES:")
        print("-" * 30)

        for i, batch in enumerate(batches, 1):
            batch_dir = Path(drive_base) / batch

            try:
                vtt_files = len(list(batch_dir.glob("*.vtt")))
                wav_files = len(list(batch_dir.glob("*.wav")))
            except OSError:
                # Only filesystem errors are expected while counting files
                print(f"{i}. {batch} (error reading)")
                continue

            timestamp = batch.replace('training_batch_', '').replace('_', ' ')
            print(f"{i}. {batch}")
            print(f"   📅 Created: {timestamp}")
            print(f"   📊 Files: {vtt_files} VTT, {wav_files} WAV")
            print()

        return None

    # Select batch to use
    if batch_name and batch_name in batches:
        selected_batch = batch_name
    else:
        selected_batch = batches[0]  # Use newest
        if batch_name:
            # Do not silently train on different data than was requested
            print(f"⚠️  Batch '{batch_name}' not found; using newest batch instead: {selected_batch}")
        else:
            print(f"🎯 Auto-selecting newest batch: {selected_batch}")

    # Copy batch to local Colab storage
    source_path = os.path.join(drive_base, selected_batch)
    local_path = "/content/training_data"

    print(f"📥 Loading batch: {selected_batch}")
    print(f"📂 Copying to: {local_path}")

    # Start from an empty local directory so files from an earlier batch
    # cannot leak into this one
    if os.path.exists(local_path):
        shutil.rmtree(local_path)
    os.makedirs(local_path)

    # Copy only the file types used for training
    wanted_suffixes = {'.wav', '.vtt', '.txt'}
    files_copied = 0
    for file_path in Path(source_path).iterdir():
        if file_path.is_file() and file_path.suffix.lower() in wanted_suffixes:
            shutil.copy2(file_path, Path(local_path) / file_path.name)
            files_copied += 1

    print(f"✅ Copied {files_copied} files")

    # Convert VTT files to TXT format
    convert_vtt_to_txt(local_path)

    return local_path

def convert_vtt_to_txt(data_dir):
    """Write a plain-text transcript next to every ``.vtt`` file in ``data_dir``.

    For each ``*.vtt`` file without an existing ``.txt`` sibling, the caption
    text is extracted with ``extract_text_from_vtt`` and written to the sibling
    ``.txt`` file, unless the extracted text is blank. Existing ``.txt`` files
    are never overwritten. Failures on individual files are printed and skipped.

    Args:
        data_dir: Directory to scan (non-recursive).

    Returns:
        None.
    """

    sources = list(Path(data_dir).glob("*.vtt"))
    if len(sources) == 0:
        return

    print(f"📝 Converting {len(sources)} VTT files to TXT...")

    written = 0
    for source in sources:
        target = source.with_suffix('.txt')

        # An existing transcript wins over a fresh conversion
        if target.exists():
            continue

        try:
            raw_captions = source.read_text(encoding='utf-8')
            transcript = extract_text_from_vtt(raw_captions)

            # Nothing but whitespace: do not create an empty transcript file
            if not transcript.strip():
                continue

            target.write_text(transcript, encoding='utf-8')
            written += 1

        except Exception as e:
            print(f"⚠️  Error converting {source.name}: {e}")

    if written > 0:
        print(f"✅ Converted {written} VTT files to TXT")

def extract_text_from_vtt(vtt_content):
    """Return the spoken text of a VTT document as one space-joined string.

    Lines are stripped and then discarded when they are empty, start with
    ``WEBVTT``, ``Kind:``, ``Language:`` or ``align:``, contain ``-->`` (cue
    timing lines), or consist only of digits (cue numbers). Every remaining
    line is passed through ``clean_vtt_formatting``; lines that end up empty
    after cleaning are dropped as well.

    Args:
        vtt_content: Full text of a ``.vtt`` file.

    Returns:
        str: The cleaned caption lines joined with single spaces.
    """

    non_text_prefixes = ('WEBVTT', 'Kind:', 'Language:', 'align:')
    fragments = []

    for raw_line in vtt_content.split('\n'):
        candidate = raw_line.strip()

        if not candidate:
            continue
        if candidate.startswith(non_text_prefixes):
            continue
        if '-->' in candidate or candidate.isdigit():
            continue

        fragment = clean_vtt_formatting(candidate)
        if fragment:
            fragments.append(fragment)

    return ' '.join(fragments)

def clean_vtt_formatting(text):
    """Remove VTT cue markup from one caption line.

    Strips cue span tags (``<c>``, ``<i>``, ``<b>``, ``<u>``, ``<v>``,
    ``<lang>``, ``<ruby>``, ``<rt>``), including forms that carry classes or
    annotations such as ``<c.colorE5E5E5>`` or ``<v Speaker>``, inline cue
    timestamps (``<hh:mm:ss.ttt>`` and ``<mm:ss.ttt>``), and
    ``align:start position:N%`` cue settings. Runs of whitespace are then
    collapsed to single spaces.

    Args:
        text: A single caption line.

    Returns:
        str: The line without markup, stripped of surrounding whitespace.
    """

    # Opening/closing cue tags; the optional group covers ".class" suffixes
    # and " annotation" text so classed tags do not leak into the transcript.
    text = re.sub(r'</?(?:c|i|b|u|v|lang|ruby|rt)(?:[.\s][^>]*)?>', '', text)
    # Inline timestamps; the hours component is optional
    text = re.sub(r'<(?:\d{2,}:)?\d{2}:\d{2}\.\d{3}>', '', text)
    text = re.sub(r'align:start position:\d+%', '', text)
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

# ============================================
# DATASET CREATION
# ============================================

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad audio features and label ids into one training batch.

    Audio is padded through ``processor.feature_extractor.pad`` and labels
    through ``processor.tokenizer.pad``; label padding is replaced with -100.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
    processor: Any
    # Token id removed from the front of the labels when all rows start with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (each with ``input_features`` and ``labels``)."""
        extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        batch = extractor.pad(
            [{"input_features": entry["input_features"]} for entry in features],
            return_tensors="pt",
        )
        label_batch = tokenizer.pad(
            [{"input_ids": entry["labels"]} for entry in features],
            return_tensors="pt",
        )

        # Anything the attention mask does not mark with 1 is padding
        labels = label_batch["input_ids"].masked_fill(label_batch.attention_mask.ne(1), -100)

        # Drop the first column when it is the start token in every row
        first_column = labels[:, 0]
        if (first_column == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

def create_dataset_from_path(data_path, processor, chunk_size=50):
    """Create training dataset from local data path

    Pairs every ``<name>.wav`` in ``data_path`` with ``<name>.txt``, loads the
    audio at 16 kHz, and converts each pair to ``input_features`` / ``labels``.
    Pairs with empty text or with at most one second of audio are skipped, as
    are files that fail to load (a warning is printed for those).

    Args:
        data_path: Directory containing the ``.wav`` and ``.txt`` files.
        processor: Whisper processor used for feature extraction and tokenizing.
        chunk_size: Unused; kept so existing callers keep working.

    Returns:
        Dataset | None: The dataset, or ``None`` when the directory has no
        WAV or TXT files or no valid audio/text pair could be built.
    """
    max_label_tokens = 448  # same label cap as the other training cell

    print(f"📂 Creating dataset from: {data_path}")

    data_dir = Path(data_path)
    # Sorted so the example order does not depend on directory listing order
    wav_files = sorted(data_dir.glob("*.wav"))
    txt_files = list(data_dir.glob("*.txt"))

    print(f"📊 Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if not wav_files or not txt_files:
        print("❌ No matching audio/text files found!")
        return None

    dataset_items = []

    for wav_file in wav_files:
        txt_file = data_dir / f"{wav_file.stem}.txt"
        if not txt_file.exists():
            continue

        try:
            audio, _ = librosa.load(str(wav_file), sr=16000)
            text = txt_file.read_text(encoding='utf-8').strip()

            if not text or len(audio) <= 16000:  # need text and more than 1 second
                continue

            inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
            input_features = inputs.input_features.squeeze(0)

            # Truncate labels: a transcript of a whole recording can exceed
            # the decoder's maximum target length and fail during training.
            # NOTE(review): audio and text are truncated independently, so
            # long files may end up misaligned — consider segmenting upstream.
            labels = processor.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_label_tokens,
            ).input_ids.squeeze(0)

            dataset_items.append({
                "input_features": input_features,
                "labels": labels
            })

            if len(dataset_items) % 10 == 0:
                print(f"   Processed {len(dataset_items)} files...")

        except Exception as e:
            # Best effort: report the file and move on to the next one
            print(f"⚠️  Error processing {wav_file.name}: {e}")

    if not dataset_items:
        print("❌ No valid audio-text pairs created!")
        return None

    # Create HuggingFace dataset
    dataset = Dataset.from_list(dataset_items)
    print(f"✅ Created dataset with {len(dataset)} examples")

    return dataset

# ============================================
# TRAINING FUNCTIONS
# ============================================

def train_whisper_with_data(data_path=None):
    """Fine-tune ``openai/whisper-small`` on the audio/text pairs found at *data_path*.

    The run uses fixed hyperparameters (one epoch, per-device batch size 8 with
    2 gradient-accumulation steps, learning rate 2e-5) and writes checkpoints
    and the final model to a fixed folder on Google Drive.

    Args:
        data_path: Location of the training data; passed unchanged to
            ``create_dataset_from_path``.

    Returns:
        The output directory (str) when training and saving succeed, or
        ``None`` when no dataset could be built or training raised.
    """

    print("🚀 WHISPER TRAINING WITH GOOGLE DRIVE SUPPORT")
    print("=" * 55)

    # Check GPU
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"🖥️  Device: {device}")
    if use_cuda:
        print(f"   GPU: {torch.cuda.get_device_name()}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Load model and processor
    model_name = "openai/whisper-small"
    print(f"📥 Loading model: {model_name}")
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)

    if use_cuda:
        model = model.to(device)
        print("✅ Model moved to GPU")

    # Create dataset
    dataset = create_dataset_from_path(data_path, processor)
    if dataset is None:
        return None

    # Create output directory
    model_output_dir = "/content/drive/MyDrive/sermon_data/models/whisper-sermon-drive"
    os.makedirs(model_output_dir, exist_ok=True)

    # Keep the batch settings in one place so the banner below cannot drift
    # away from what the trainer actually uses.
    per_device_batch_size = 8
    grad_accum_steps = 2

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=model_output_dir,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        num_train_epochs=1,
        logging_steps=5,
        save_steps=50,
        save_total_limit=3,
        save_strategy="steps",
        predict_with_generate=True,
        # Mixed precision is rejected by TrainingArguments on a CPU-only
        # runtime, so only request it when a GPU is actually present.
        fp16=use_cuda,
        # Pinned host memory only helps host-to-GPU copies.
        dataloader_pin_memory=use_cuda,
        remove_unused_columns=False,
        report_to="none",
        dataloader_num_workers=2,
        learning_rate=2e-5,
        group_by_length=False,
    )

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        tokenizer=processor.tokenizer,
    )

    # Start training
    print("🎯 Starting training...")
    print(f"📊 Dataset size: {len(dataset)} examples")
    print(
        f"🚀 Batch size: {per_device_batch_size}, "
        f"Effective batch size: {per_device_batch_size * grad_accum_steps}"
    )

    try:
        trainer.train()
        trainer.save_model()
        # The trainer only knows about the tokenizer, so save the full
        # processor as well; this writes the feature extractor's
        # preprocessor_config.json next to the weights so the folder can be
        # reloaded with WhisperProcessor.from_pretrained().
        processor.save_pretrained(model_output_dir)
        print(f"✅ Training complete! Model saved to: {model_output_dir}")
        return model_output_dir

    except Exception as e:
        # Best-effort boundary: report the failure and free memory so the
        # notebook session stays usable for another attempt.
        print(f"❌ Training error: {e}")
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
        return None

# ============================================
# EASY TRAINING FUNCTIONS
# ============================================

def quick_train_newest():
    """Train on the batch that ``load_training_data_from_drive()`` picks by default.

    Returns:
        The result of ``train_whisper_with_data`` for that batch, or ``None``
        when the loader yields no usable path.
    """
    newest_batch_path = load_training_data_from_drive()
    if not newest_batch_path:
        print("❌ No Google Drive data available")
        return None
    return train_whisper_with_data(newest_batch_path)

def list_and_train():
    """Show the available batches, ask the user for one, then train on it.

    An empty answer is forwarded to the loader as ``None`` (the prompt
    describes this as "newest").

    Returns:
        The result of ``train_whisper_with_data``, or ``None`` when the
        loader yields no usable path.
    """
    print("📋 Available Google Drive batches:")
    load_training_data_from_drive(list_only=True)

    answer = input("\nEnter batch name (or press Enter for newest): ").strip()
    selected_path = load_training_data_from_drive(answer or None)

    if not selected_path:
        print("❌ Failed to load data")
        return None
    return train_whisper_with_data(selected_path)

def train_with_specific_batch(batch_name):
    """Train on the Google Drive batch identified by *batch_name*.

    Args:
        batch_name: Batch identifier handed to ``load_training_data_from_drive``.

    Returns:
        The result of ``train_whisper_with_data``, or ``None`` when the
        loader yields no usable path for that batch.
    """
    batch_path = load_training_data_from_drive(batch_name)
    if not batch_path:
        print(f"❌ Failed to load batch: {batch_name}")
        return None
    return train_whisper_with_data(batch_path)

# ============================================
# USAGE INSTRUCTIONS
# ============================================

# Announce that the helpers are loaded and show the available entry points.
# A single print with a newline separator emits exactly the same text as one
# print call per line.
print(
    "🚀 COLAB WHISPER TRAINING WITH GOOGLE DRIVE LOADED!",
    "\n💡 USAGE OPTIONS:",
    "   quick_train_newest()                    # Train with newest Google Drive batch",
    "   list_and_train()                       # List batches then choose one",
    "   train_with_specific_batch('batch_name') # Train with specific batch",
    "   load_training_data_from_drive(list_only=True)  # Just list available batches",
    "\n🎯 Quick start (recommended):",
    "   quick_train_newest()",
    sep="\n",
)

# Uncomment to start training immediately:
# model_path = quick_train_newest()

In [None]:
# Mount Google Drive at /content/drive; the other cells read and write their
# data and models under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ============================================
# CELL 4: Test the Trained Model
# ============================================

import os
import torch
import json
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa

def create_preprocessor_config(model_path):
    """Write a minimal ``preprocessor_config.json`` into *model_path*.

    The file only names the Whisper feature-extractor, processor and tokenizer
    classes. A file of the same name that already exists is overwritten.

    Args:
        model_path: Directory that should receive the config file.

    Returns:
        ``True`` when the file was written, ``False`` when anything raised
        along the way (the error is printed, not propagated).
    """
    payload = {
        "feature_extractor_type": "WhisperFeatureExtractor",
        "processor_class": "WhisperProcessor",
        "tokenizer_class": "WhisperTokenizer",
    }

    try:
        config_path = os.path.join(model_path, "preprocessor_config.json")
        with open(config_path, 'w') as handle:
            handle.write(json.dumps(payload, indent=2))
        print(f"✅ Created preprocessor_config.json at {config_path}")
    except Exception as err:
        print(f"❌ Failed to create preprocessor_config.json: {err}")
        return False
    return True

def _checkpoint_step(path):
    """Return the number at the end of a ``checkpoint-<N>`` folder name, or -1."""
    suffix = os.path.basename(os.path.normpath(path)).rpartition("-")[2]
    return int(suffix) if suffix.isdigit() else -1


def _expand_model_paths(candidates):
    """Expand glob patterns in *candidates* into real paths, highest checkpoint first."""
    import glob

    expanded = []
    for candidate in candidates:
        if any(ch in candidate for ch in "*?["):
            matches = sorted(glob.glob(candidate), key=_checkpoint_step, reverse=True)
            # Keep the raw pattern when nothing matches so the caller still
            # reports it as a missing path.
            expanded.extend(matches or [candidate])
        else:
            expanded.append(candidate)
    return expanded


def _load_processor_with_fallback(path, base_model="openai/whisper-small"):
    """Load a WhisperProcessor from *path*, falling back to *base_model*'s processor."""
    try:
        processor = WhisperProcessor.from_pretrained(path)
        print("✅ Loaded processor from saved model")
        return processor
    except Exception as proc_error:
        print(f"⚠️  Could not load processor from saved model: {proc_error}")
        # Evaluate here: the exception name is unbound once the handler ends.
        missing_config = "preprocessor_config.json" in str(proc_error)

    if missing_config:
        print("🔧 Creating missing preprocessor_config.json...")
        if create_preprocessor_config(path):
            try:
                processor = WhisperProcessor.from_pretrained(path)
                print("✅ Loaded processor after creating config")
                return processor
            except Exception as retry_error:
                print(f"❌ Still failed after creating config: {retry_error}")

    print("🔄 Loading processor from original model...")
    processor = WhisperProcessor.from_pretrained(base_model)
    print("✅ Loaded processor from original model")
    return processor


def _report_training_processes():
    """Print the PID of every Python process whose command line mentions training."""
    print("\n🔍 Checking for training processes...")
    try:
        import psutil
    except ImportError:
        print("⚠️  psutil not available for process checking")
        return

    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        # psutil reports None for fields it is not allowed to read.
        name = proc.info['name'] or ''
        if 'python' not in name.lower():
            continue
        cmdline = ' '.join(proc.info['cmdline'] or []).lower()
        if 'train' in cmdline or 'whisper' in cmdline:
            print(f"🔄 Training process found: PID {proc.info['pid']}")


def _transcribe(model, processor, audio, device):
    """Run *model* on 16 kHz *audio* and return the decoded text."""
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    # Moving to "cpu" is a no-op, so one code path serves both devices.
    input_features = inputs["input_features"].to(device)
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def test_trained_model(model_paths=None, test_audio_path=None):
    """Load the fine-tuned model and compare it with stock Whisper on one clip.

    Each candidate path is checked in order; the first directory that holds a
    ``config.json`` plus weight files and loads successfully is used. The
    first 30 seconds of the test clip are transcribed with that model and
    with ``openai/whisper-small``, and both transcriptions are printed.

    Args:
        model_paths: Candidate model directories, tried in order. Entries may
            be glob patterns (for example ``.../checkpoint-*``); matches are
            tried highest checkpoint number first. Defaults to the notebook's
            Google Drive locations.
        test_audio_path: Audio file to transcribe. Defaults to the notebook's
            sample sermon recording.

    Returns:
        None. All findings are printed.
    """

    print("🧪 TESTING TRAINED MODEL")
    print("=" * 50)

    if model_paths is None:
        model_paths = [
            "/content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50",
            "/content/drive/MyDrive/sermon_data/models",
            "/content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50/checkpoint-*",
            # Folder that train_whisper_with_data() in this file saves to.
            "/content/drive/MyDrive/sermon_data/models/whisper-sermon-drive",
        ]
    if test_audio_path is None:
        test_audio_path = "/content/drive/MyDrive/sermon_data/sermon_1-MyKI27Kic.wav"

    model_path = None
    processor = None
    model = None

    # Try to find and load the model
    for path in _expand_model_paths(model_paths):
        print(f"🔍 Checking path: {path}")

        if not os.path.exists(path):
            print(f"❌ Path does not exist: {path}")
            continue
        print(f"✅ Path exists: {path}")

        if not os.path.isdir(path):
            print(f"⚠️  {path} is not a directory")
            continue

        try:
            files = os.listdir(path)
        except OSError as list_error:
            print(f"❌ Could not list directory {path}: {list_error}")
            continue
        print(f"📋 Files in directory: {files}")

        # training_args.bin is not a weight file, and only config.json itself
        # (not tokenizer_config.json etc.) marks a loadable model folder.
        weight_files = [
            f for f in files
            if f.endswith(('.safetensors', '.bin')) and f != 'training_args.bin'
        ]
        if not weight_files or 'config.json' not in files:
            print(f"⚠️  No model files found in {path}")
            continue
        print(f"✅ Found model files in: {path}")

        try:
            processor = _load_processor_with_fallback(path)
            model = WhisperForConditionalGeneration.from_pretrained(path)
        except Exception as load_error:
            print(f"❌ Failed to load from {path}: {load_error}")
            continue

        model_path = path
        print("✅ Successfully loaded trained model!")
        break

    if model is None:
        print("\n❌ Could not load trained model from any path!")
        print("💡 Possible issues:")
        print("   1. Training might still be in progress")
        print("   2. Model might be saved in a different location")
        print("   3. Model files might be corrupted")

        # Check if training is still running
        _report_training_processes()
        return

    print(f"✅ Model loaded successfully from: {model_path}")

    if not os.path.exists(test_audio_path):
        print("❌ Test audio file not found!")
        print("🔍 Available audio files in sermon_data:")
        sermon_data_dir = "/content/drive/MyDrive/sermon_data"
        if os.path.exists(sermon_data_dir):
            try:
                for item in os.listdir(sermon_data_dir):
                    if item.endswith('.wav'):
                        print(f"   - {item}")
            except Exception as e:
                print(f"❌ Could not list directory: {e}")
        return

    print(f"\n🎵 Testing on: {os.path.basename(test_audio_path)}")

    # Load audio and keep only the first 30 seconds
    audio, sr = librosa.load(test_audio_path, sr=16000)
    audio_sample = audio[:30*sr]

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to(device)
        print(f"🖥️  Model moved to GPU: {torch.cuda.get_device_name()}")

    transcription = _transcribe(model, processor, audio_sample, device)
    print(f"📝 Your Model Transcription: {transcription}")

    # Compare with original Whisper
    print("\n🔄 Comparing with original Whisper...")
    original_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    original_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    if device == "cuda":
        original_model = original_model.to(device)

    transcription_orig = _transcribe(original_model, original_processor, audio_sample, device)

    print(f"📝 Original Whisper: {transcription_orig}")
    print(f"🎯 Your Model: {transcription}")

    if transcription != transcription_orig:
        print("✅ Model has learned something different!")
        print("🎉 Training appears to be successful!")
    else:
        print("⚠️  Model output same as original")
        print("💡 This might indicate the model needs more training")

# Run the model check as soon as this cell executes; results are printed below.
test_trained_model()

🧪 TESTING TRAINED MODEL
🔍 Checking path: /content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50
✅ Path exists: /content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50
📋 Files in directory: ['checkpoint-1050', 'checkpoint-1100', 'checkpoint-1150', 'checkpoint-1200', 'checkpoint-1228', 'config.json', 'generation_config.json', 'model.safetensors', 'tokenizer_config.json', 'special_tokens_map.json', 'added_tokens.json', 'vocab.json', 'merges.txt', 'normalizer.json', 'training_args.bin', 'preprocessor_config.json']
✅ Found model files in: /content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50
✅ Loaded processor from saved model
✅ Successfully loaded trained model!
✅ Model loaded successfully from: /content/drive/MyDrive/sermon_data/models/whisper-sermon-cycling-50

🎵 Testing on: sermon_1-MyKI27Kic.wav


Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
`generation_config` default values have been modified to match model-specific defaults: {'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161

📝 Your Model Transcription:  You know, I was just uh talking to we were talking to um John and Erica a few weeks ago and we were discussing that now how the 90s are a theme. Like, you know, they used to have like 70s parties and 80s parties. Well, now like my uh my niece went to a 90s themed party. And we were kind of laughing because when they think of uh

🔄 Comparing with original Whisper...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

📝 Original Whisper:  You know, I was just talking to, we were talking to John and Erica a few weeks ago, and we were discussing that now how the 90s are a theme. Like, you know, they used to have like 70s parties and 80s parties, well now like my niece went to a 90s themed party, and we were kind of laughing because when they think of,
🎯 Your Model:  You know, I was just uh talking to we were talking to um John and Erica a few weeks ago and we were discussing that now how the 90s are a theme. Like, you know, they used to have like 70s parties and 80s parties. Well, now like my uh my niece went to a 90s themed party. And we were kind of laughing because when they think of uh
✅ Model has learned something different!
🎉 Training appears to be successful!


In [None]:
   import gc
   import torch

   # Free memory between cells: run Python's garbage collector first, then,
   # only when a GPU is present, ask PyTorch to empty its CUDA cache.
   gc.collect()
   if torch.cuda.is_available():
       torch.cuda.empty_cache()

In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')

# List the contents of the sermon_data folder, one level deep: each entry at
# the top level, plus the direct children of every subfolder.
sermon_data_path = "/content/drive/MyDrive/sermon_data"
if os.path.exists(sermon_data_path):
    print("Contents of sermon_data folder:")
    for item in os.listdir(sermon_data_path):
        item_path = os.path.join(sermon_data_path, item)
        if os.path.isdir(item_path):
            print(f"📁 {item}/")
            # List contents of subdirectories
            try:
                for subitem in os.listdir(item_path):
                    print(f"   - {subitem}")
            except OSError as list_error:
                # Report unreadable subfolders instead of hiding them, then
                # carry on with the remaining entries.
                print(f"   ⚠️  Could not list {item}/: {list_error}")
        else:
            # The original marker was mangled into replacement characters.
            print(f"📄 {item}")
else:
    print("❌ sermon_data folder not found!")