In [None]:
# -*- coding: utf-8 -*-
"""agent-speech-model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1JSW55ja-SiFXY9dTahib74r1dorBNL5G
"""

# ===============================================================================
# COMPLETE ACCENT CLASSIFICATION SCRIPT FOR GOOGLE COLAB (May 2025)
# ===============================================================================
# This script extracts audio from video URLs and classifies English accents
# Compatible with YouTube, Loom, and direct MP4 links
# Addresses common issues with YouTube bot detection and audio extraction
# ===============================================================================

# CELL 1: DEPENDENCY INSTALLATION
# Run this cell first - it installs all required packages and system dependencies
# Cell 1 has three phases: install Python packages with pip, install system
# audio tools with apt-get, then import the critical packages to confirm the
# installation actually worked. The `!` lines are notebook shell magics and
# only run inside a Colab/IPython environment.
print("🚀 Starting complete dependency installation...")
print("This may take 2-3 minutes. Please wait for completion without interruption.")
print("=" * 70)

# Install Python packages with pinned torch version to avoid conflicts
# (torch is the only package pinned here; every other package is upgraded to
# whatever pip resolves at install time).
!pip install --upgrade --quiet \
    speechbrain \
    torch==2.6.0 \
    torchaudio \
    gradio \
    yt-dlp \
    librosa \
    noisereduce \
    demucs \
    scipy \
    numpy \
    requests \
    pydub

# Install system dependencies
# ffmpeg is requested by the yt-dlp 'FFmpegExtractAudio' post-processor used
# later in this file; sox and its format plugins are installed alongside it.
!apt-get update -qq
!apt-get install -y -qq ffmpeg sox libsox-fmt-all

# Verify installations
# Import the critical packages and print their versions. Any ImportError is
# reported and re-raised so the remaining cells do not run on a broken setup.
print("\n✅ Verifying critical installations...")
try:
    import torch
    import torchaudio
    import yt_dlp
    import librosa
    print(f"   PyTorch: {torch.__version__}")
    print(f"   Torchaudio: {torchaudio.__version__}")
    try:
        # The version lives in a submodule; tolerate it being unavailable.
        from yt_dlp.version import __version__ as yt_dlp_version
        print(f"   yt-dlp: {yt_dlp_version}")
    except ImportError:
        print("   yt-dlp: Installed (version not accessible)")
    print(f"   Librosa: {librosa.__version__}")
    print("✅ All dependencies installed successfully!")
except ImportError as e:
    print(f"❌ Installation error: {e}")
    print("Please restart runtime and try again.")
    raise  # Stop execution if dependencies are missing

print("=" * 70)
print("🎯 Ready to proceed with the main script!")

# ===============================================================================
# CELL 2: IMPORTS AND CONFIGURATION
# ===============================================================================

import os
import torch
import torchaudio
import gradio as gr
import yt_dlp
import shutil
import librosa
import numpy as np
import noisereduce as nr
from scipy import signal
from speechbrain.inference import EncoderClassifier
import logging
import uuid
import re
import time
import tempfile
from pathlib import Path
import requests
from pydub import AudioSegment

# Configure logging to reduce noise
# Only WARNING and above from these two loggers will be emitted.
logging.getLogger("speechbrain").setLevel(logging.WARNING)
logging.getLogger("yt_dlp").setLevel(logging.WARNING)

# ===============================================================================
# CONFIGURATION SETTINGS
# ===============================================================================

# Device configuration - use GPU if available
# DEVICE is passed to the model loaders below via run_opts={"device": DEVICE}.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️  Using device: {DEVICE}")

# Model configuration
# MODEL_NAME is the source for load_language_classifier();
# ACCENT_MODEL is the source for load_accent_classifier().
MODEL_NAME = "speechbrain/lang-id-commonlanguage_ecapa"
ACCENT_MODEL = "Jzuluaga/accent-id-commonaccent_ecapa"

# Directory for temporary files
# Created at import time. cleanup_temp_files() deletes every regular file in
# this directory, so nothing that must persist should be stored here.
TEMP_AUDIO_DIR = "/content/temp_audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)

# Global model variables
# Lazily populated caches: each loader fills its variable on first use and
# returns the cached instance afterwards.
accent_classifier = None
language_classifier = None

print(f"📁 Temporary files will be stored in: {TEMP_AUDIO_DIR}")
print("⚙️  Configuration complete!")

# ===============================================================================
# CELL 3: MODEL LOADING FUNCTIONS
# ===============================================================================

def load_accent_classifier():
    """Return the shared accent classifier, building it on the first call.

    The instance is cached in the module-level ``accent_classifier`` variable,
    so later calls return immediately. Model files are stored under
    ``/content/models/accent_classifier`` and the model is placed on ``DEVICE``.

    Returns:
        The ``EncoderClassifier`` created from ``ACCENT_MODEL``.

    Raises:
        RuntimeError: If the model cannot be created for any reason.
    """
    global accent_classifier

    # Fast path: already loaded during an earlier call.
    if accent_classifier is not None:
        return accent_classifier

    print(f"🔄 Loading accent classification model: {ACCENT_MODEL}")
    try:
        cache_dir = os.path.join("/content/models", "accent_classifier")
        os.makedirs(cache_dir, exist_ok=True)

        accent_classifier = EncoderClassifier.from_hparams(
            source=ACCENT_MODEL,
            savedir=cache_dir,
            run_opts={"device": DEVICE},
        )
        print("✅ Accent classification model loaded successfully!")
    except Exception as e:
        print(f"❌ Failed to load accent model: {e}")
        raise RuntimeError(f"Could not load accent classification model: {e}")

    return accent_classifier

def load_language_classifier():
    """Return the shared language-ID classifier, or ``None`` if it cannot load.

    The instance is cached in the module-level ``language_classifier``
    variable. Unlike the accent loader, a failure here is not fatal: a warning
    is printed and ``None`` is returned. Because the cache stays ``None`` after
    a failure, the next call will attempt the load again.

    Returns:
        The ``EncoderClassifier`` created from ``MODEL_NAME``, or ``None``.
    """
    global language_classifier

    # Fast path: already loaded during an earlier call.
    if language_classifier is not None:
        return language_classifier

    print(f"🔄 Loading language classification model: {MODEL_NAME}")
    try:
        cache_dir = os.path.join("/content/models", "language_classifier")
        os.makedirs(cache_dir, exist_ok=True)

        language_classifier = EncoderClassifier.from_hparams(
            source=MODEL_NAME,
            savedir=cache_dir,
            run_opts={"device": DEVICE},
        )
        print("✅ Language classification model loaded successfully!")
    except Exception as e:
        print(f"⚠️  Language model loading failed: {e}")
        language_classifier = None

    return language_classifier

# ===============================================================================
# CELL 4: AUDIO PREPROCESSING FUNCTIONS
# ===============================================================================

def apply_bandpass_filter(audio_data, sr, lowcut=80, highcut=8000):
    """Apply a 4th-order Butterworth band-pass filter to focus on speech.

    Args:
        audio_data: Samples to filter (filtered along the last axis).
        sr: Sample rate of ``audio_data`` in Hz.
        lowcut: Lower band edge in Hz.
        highcut: Upper band edge in Hz. Values at or above the Nyquist
            frequency are clamped to just below it.

    Returns:
        The filtered samples. The input is returned unchanged if the band is
        still invalid after clamping (for example ``lowcut`` above Nyquist) or
        if filtering fails for any reason.
    """
    try:
        nyquist = 0.5 * sr

        # The normalised band edges must lie strictly inside (0, 1). With the
        # defaults and the 16 kHz audio produced by preprocess_audio, highcut
        # equals Nyquist exactly, which used to make the range check fail and
        # silently skip the filter on every call. Clamp instead of skipping.
        effective_highcut = min(highcut, 0.99 * nyquist)

        low = lowcut / nyquist
        high = effective_highcut / nyquist

        if not (0 < low < high < 1.0):
            print(f"⚠️  Invalid filter range, skipping filter")
            return audio_data

        # Second-order sections are numerically better behaved than the
        # (b, a) polynomial form when band edges sit close to 0 or Nyquist.
        sos = signal.butter(4, [low, high], btype='band', output='sos')
        filtered_audio = signal.sosfilt(sos, audio_data)

        # Keep the caller's floating dtype (e.g. float32 from librosa) rather
        # than silently widening to float64.
        if isinstance(audio_data, np.ndarray) and np.issubdtype(audio_data.dtype, np.floating):
            filtered_audio = filtered_audio.astype(audio_data.dtype, copy=False)

        return filtered_audio
    except Exception as e:
        # Best-effort step: never let filtering abort the pipeline.
        print(f"⚠️  Bandpass filter failed: {e}")
        return audio_data

def preprocess_audio(input_path, output_path):
    """Clean up an audio file and write it as 16 kHz mono audio.

    Steps, in order: load and resample to 16 kHz mono, reject clips that are
    too short or silent, trim leading/trailing quiet sections, apply
    non-stationary noise reduction (best effort), band-pass filter, peak
    normalise to 0.95, keep at most the first 60 seconds, and save.

    Args:
        input_path: Path of the audio file to read.
        output_path: Path the processed audio is written to.

    Returns:
        ``output_path``. If any step fails and the paths differ, the original
        file is copied to ``output_path`` as a fallback and that path is
        still returned.

    Raises:
        RuntimeError: If preprocessing fails and the fallback copy is not
            possible (or the two paths are the same). The original error is
            attached as the cause.
    """
    print(f"🔧 Preprocessing audio: {input_path}")

    try:
        audio_data, sr = librosa.load(input_path, sr=16000, mono=True)

        # Need at least half a second of audio.
        if len(audio_data) < sr * 0.5:
            raise ValueError("Audio too short for reliable classification")

        if np.max(np.abs(audio_data)) < 1e-6:
            raise ValueError("Audio appears to be silent")

        audio_data, _ = librosa.effects.trim(audio_data, top_db=20)

        # Noise reduction is optional: keep going with the un-denoised signal
        # if it fails.
        try:
            audio_data = nr.reduce_noise(
                y=audio_data,
                sr=sr,
                stationary=False,
                prop_decrease=0.8
            )
            print("✅ Noise reduction applied")
        except Exception as e:
            print(f"⚠️  Noise reduction failed: {e}")

        audio_data = apply_bandpass_filter(audio_data, sr)

        # Peak-normalise, leaving a little headroom; skip if effectively silent
        # to avoid dividing by ~0.
        max_val = np.max(np.abs(audio_data))
        if max_val > 1e-6:
            audio_data = audio_data / max_val * 0.95

        max_length = sr * 60
        if len(audio_data) > max_length:
            audio_data = audio_data[:max_length]
            print("🔄 Audio trimmed to first 60 seconds")

        # Earlier steps may hand back float64 or a non-contiguous array;
        # normalise to contiguous float32 before saving.
        # NOTE(review): float32 is chosen because torchaudio.save is believed
        # not to accept float64 on every backend -- confirm against its docs.
        waveform = torch.from_numpy(
            np.ascontiguousarray(audio_data, dtype=np.float32)
        ).unsqueeze(0)
        torchaudio.save(output_path, waveform, sr)
        print(f"✅ Preprocessed audio saved: {output_path}")

        return output_path

    except Exception as e:
        print(f"❌ Preprocessing failed: {e}")
        if input_path != output_path:
            try:
                shutil.copy2(input_path, output_path)
            except OSError as copy_error:
                # Report why the fallback failed instead of hiding it.
                print(f"⚠️  Could not fall back to original audio: {copy_error}")
            else:
                print("🔄 Using original audio as fallback")
                return output_path
        raise RuntimeError(f"Audio preprocessing failed: {e}") from e

# ===============================================================================
# CELL 5: ROBUST AUDIO EXTRACTION FUNCTIONS
# ===============================================================================

def is_valid_video_url(url):
    """Return True if ``url`` looks like a link this tool is willing to process.

    A URL is accepted when it is a non-empty string that (after stripping
    whitespace) starts with ``http://`` or ``https://`` and matches, anywhere
    in the string and case-insensitively, one of the known host names or
    video file extensions listed below.
    """
    if not url or not isinstance(url, str):
        return False

    candidate = url.strip()
    if not candidate.startswith(('http://', 'https://')):
        return False

    # Known hosts first, then direct-file extensions (which must be followed
    # by the end of the string or a query string).
    accepted_patterns = (
        r'(youtube\.com|youtu\.be)',
        r'loom\.com',
        r'vimeo\.com',
        r'.*\.mp4($|\?)',
        r'.*\.mov($|\?)',
        r'.*\.avi($|\?)',
        r'drive\.google\.com',
        r'dropbox\.com',
    )

    for pattern in accepted_patterns:
        if re.search(pattern, candidate, re.IGNORECASE):
            return True
    return False

def _collect_downloaded_audio(raw_audio_path, extensions):
    """Locate a downloaded file for this job and normalise it to ``raw_audio_path``.

    Candidates share ``raw_audio_path``'s stem and are tried in the order of
    ``extensions``; files of 1000 bytes or less are ignored. A non-WAV hit is
    converted to WAV and the source file removed. Returns True on a hit.
    """
    for ext in extensions:
        candidate = raw_audio_path.replace('.wav', ext)
        if os.path.exists(candidate) and os.path.getsize(candidate) > 1000:
            if ext != '.wav':
                AudioSegment.from_file(candidate).export(raw_audio_path, format="wav")
                os.remove(candidate)
            return True
    return False


def _download_audio(video_url, ydl_opts, raw_audio_path, extensions):
    """Run one yt-dlp download with ``ydl_opts`` and collect its output file."""
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([video_url])
    return _collect_downloaded_audio(raw_audio_path, extensions)


def extract_audio_robust(video_url):
    """Download a video's audio track and return a path to a preprocessed WAV.

    Strategy, in order:
      1. Download the best audio stream and convert it to WAV.
      2. If that raises *or* yields no usable file, retry with the lowest
         quality stream.
      3. If both fail, try to convert any leftover file belonging to this job.
    The resulting WAV is then passed through ``preprocess_audio``.

    Args:
        video_url: URL accepted by ``is_valid_video_url``.

    Returns:
        Path to the preprocessed WAV, or to the raw WAV if preprocessing
        raised.

    Raises:
        ValueError: If ``video_url`` fails validation.
        RuntimeError: If no audio could be extracted by any strategy.
    """
    if not is_valid_video_url(video_url):
        raise ValueError("Invalid video URL format")

    # Short random id keeps this job's files distinguishable in the shared dir.
    unique_id = str(uuid.uuid4())[:8]
    raw_audio_path = os.path.join(TEMP_AUDIO_DIR, f"raw_{unique_id}.wav")
    processed_audio_path = os.path.join(TEMP_AUDIO_DIR, f"processed_{unique_id}.wav")

    print(f"🔄 Extracting audio from: {video_url}")

    # NOTE(review): 'prefer_insecure' asks yt-dlp to prefer plain HTTP and
    # 'no_check_certificate' looks intended to disable TLS verification.
    # Confirm both are really wanted for an app that is shared publicly.
    ydl_opts_primary = {
        'format': 'bestaudio[ext=m4a]/bestaudio/best',
        'outtmpl': raw_audio_path.replace('.wav', '.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }],
        'quiet': True,
        'no_warnings': True,
        'extractaudio': True,
        'audioformat': 'wav',
        'audioquality': '192K',
        'prefer_insecure': True,
        'no_check_certificate': True,
        'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36',
        'extractor_retries': 5,
        'fragment_retries': 5,
        'retries': 5,
        'http_headers': {
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
            'Accept-Language': 'en-us,en;q=0.5',
            'Sec-Fetch-Mode': 'navigate',
        }
    }

    ydl_opts_fallback = {
        'format': 'worst[ext=mp4]/worst',
        'outtmpl': raw_audio_path.replace('.wav', '.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
        'quiet': True,
        'extract_flat': False,
        'no_warnings': True,
    }

    extraction_success = False

    # Strategy 1: best available audio stream.
    try:
        extraction_success = _download_audio(
            video_url, ydl_opts_primary, raw_audio_path,
            ['.wav', '.m4a', '.mp3', '.webm'],
        )
    except Exception as e:
        print(f"⚠️  Primary extraction failed: {e}")

    # Strategy 2: lowest quality stream. This now also runs when the primary
    # download finished without an exception but left no usable file.
    if not extraction_success:
        try:
            extraction_success = _download_audio(
                video_url, ydl_opts_fallback, raw_audio_path,
                ['.wav', '.mp4', '.webm'],
            )
        except Exception as e2:
            print(f"❌ Fallback extraction also failed: {e2}")

    # Strategy 3: convert any leftover file from this job. The raw WAV itself
    # is excluded: converting it onto its own path and then deleting it could
    # never succeed.
    if not extraction_success:
        raw_name = os.path.basename(raw_audio_path)
        leftovers = sorted(
            name for name in os.listdir(TEMP_AUDIO_DIR)
            if unique_id in name and name != raw_name
        )
        for leftover in leftovers:
            leftover_path = os.path.join(TEMP_AUDIO_DIR, leftover)
            try:
                AudioSegment.from_file(leftover_path).export(raw_audio_path, format="wav")
                os.remove(leftover_path)
            except Exception as e3:
                print(f"⚠️  Could not convert leftover file {leftover}: {e3}")
            else:
                extraction_success = True
                break

    if not extraction_success or not os.path.exists(raw_audio_path):
        raise RuntimeError(
            "Failed to extract audio. This could be due to:\n"
            "• YouTube bot detection (try a different video)\n"
            "• Private/restricted video content\n"
            "• Network connectivity issues\n"
            "• Unsupported video format"
        )

    print(f"✅ Audio extraction successful: {raw_audio_path}")

    try:
        processed_path = preprocess_audio(raw_audio_path, processed_audio_path)
        if os.path.exists(raw_audio_path):
            os.remove(raw_audio_path)
        return processed_path
    except Exception as e:
        # Classification can still run on the unprocessed audio.
        print(f"⚠️  Preprocessing failed, using raw audio: {e}")
        return raw_audio_path

# ===============================================================================
# CELL 6: ACCENT CLASSIFICATION FUNCTION
# ===============================================================================

def classify_accent(audio_file_path):
    """Classify the accent in an audio file and rank the top candidates.

    The classifier's raw scores are sharpened with a temperature of 0.7 and
    turned into percentages via softmax. Up to five highest-scoring labels
    above 0.1 % are kept.

    Args:
        audio_file_path: Path to the audio file to classify.

    Returns:
        A 4-tuple ``(accent, confidence, note, top_predictions)`` where
        ``confidence`` is a percentage and ``top_predictions`` is a list of
        ``(accent_name, confidence)`` pairs, best first. If classification
        itself fails, ``("Error", 0.0, <message>, [])`` is returned instead of
        raising.

    Raises:
        FileNotFoundError: If ``audio_file_path`` does not exist.
        ValueError: If the file is smaller than 1000 bytes.
        RuntimeError: Propagated from ``load_accent_classifier``.
    """
    classifier = load_accent_classifier()

    if not os.path.exists(audio_file_path):
        raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

    if os.path.getsize(audio_file_path) < 1000:
        raise ValueError("Audio file too small or corrupted")

    print(f"🎯 Classifying accent for: {audio_file_path}")

    try:
        # Only the per-class scores are used; the other outputs are ignored.
        out_prob, _score, _index, _text_lab = classifier.classify_file(audio_file_path)

        temperature = 0.7
        if hasattr(out_prob, 'squeeze'):
            raw_scores = out_prob.squeeze()
        else:
            raw_scores = torch.tensor(out_prob)
        probs = torch.softmax(raw_scores / temperature, dim=0)

        if hasattr(probs, 'cpu'):
            probs_numpy = probs.cpu().numpy()
        else:
            probs_numpy = np.array(probs)

        # Indices of the highest probabilities, best first.
        top_k = 5
        ranked_indices = np.argsort(probs_numpy)[::-1][:top_k]

        has_encoder = hasattr(classifier, 'hparams') and hasattr(classifier.hparams, 'label_encoder')
        if has_encoder:
            label_encoder = classifier.hparams.label_encoder
        else:
            # No label encoder available: fall back to placeholder names.
            placeholder_labels = {i: f"Accent_{i}" for i in range(len(probs_numpy))}
            label_encoder = type('obj', (object,), {'ind2lab': placeholder_labels})()

        top_predictions = []
        for idx in ranked_indices:
            if not idx < len(probs_numpy):
                continue

            confidence = float(probs_numpy[idx] * 100)
            if not confidence > 0.1:
                continue

            if hasattr(label_encoder, 'ind2lab') and idx in label_encoder.ind2lab:
                accent_name = label_encoder.ind2lab[idx]
            else:
                accent_name = f"Unknown_Accent_{idx}"

            # Display form: underscores become spaces, words are title-cased.
            top_predictions.append((accent_name.replace('_', ' ').title(), confidence))

        if not top_predictions:
            return "Unknown", 0.0, "No accent could be determined with confidence", []

        predicted_accent, confidence_score = top_predictions[0][0], top_predictions[0][1]

        # Three confidence tiers: below 40, below 65, and everything else.
        if confidence_score < 40:
            warning = f"**Low Confidence ({confidence_score:.1f}%)**: Results may be unreliable due to audio quality, background noise, or accent not well-represented in training data."
        elif confidence_score < 65:
            warning = f"**Moderate Confidence ({confidence_score:.1f}%)**: Results are reasonably reliable but may benefit from clearer audio."
        else:
            warning = f"**High Confidence ({confidence_score:.1f}%)**: Results are likely accurate."

        return predicted_accent, confidence_score, warning, top_predictions

    except Exception as e:
        print(f"❌ Classification error: {e}")
        error_msg = "Classification failed. This might be due to incompatible audio format or model issues."
        return "Error", 0.0, error_msg, []

# ===============================================================================
# CELL 7: MAIN PROCESSING FUNCTION
# ===============================================================================

def process_video_url(video_url):
    """Run the whole pipeline for one URL and return a Markdown report.

    Validates the URL, extracts and preprocesses the audio, classifies the
    accent and formats the outcome. This function does not raise: every
    failure is converted into a Markdown error report with troubleshooting
    hints chosen from keywords in the error text. Temporary files are removed
    afterwards regardless of the outcome.

    Args:
        video_url: URL entered by the user.

    Returns:
        Markdown text containing either the results or an error description.
    """
    try:
        if not video_url or not video_url.strip():
            return "## ⚠️ Input Error\nPlease enter a valid video URL."

        video_url = video_url.strip()

        if not is_valid_video_url(video_url):
            return (
                "## ⚠️ Invalid URL\n\n"
                "Please provide a valid video URL from:\n"
                "- YouTube (youtube.com, youtu.be)\n"
                "- Loom (loom.com)\n"
                "- Vimeo (vimeo.com)\n"
                "- Direct video files (.mp4, .mov, .avi)\n"
                "- Google Drive or Dropbox shared videos"
            )

        banner = '=' * 50
        print(f"\n{banner}")
        print(f"🎬 Processing Video URL: {video_url}")
        print(f"{banner}")

        audio_file = extract_audio_robust(video_url)
        print(f"✅ Audio ready for classification: {audio_file}")

        accent, confidence, warning, top_predictions = classify_accent(audio_file)
        report = format_results(video_url, accent, confidence, warning, top_predictions)

        print("✅ Processing complete!")
        return report

    except Exception as e:
        error_message = str(e)
        print(f"❌ Processing failed: {error_message}")

        # Choose troubleshooting hints from keywords in the error text.
        lowered = error_message.lower()
        if "youtube" in lowered:
            error_guidance = (
                "**YouTube-specific issues:**\n"
                "- Try a different YouTube video\n"
                "- Use a shorter, more recent video\n"
                "- Ensure the video is public and not age-restricted\n"
                "- Consider using alternative platforms (Loom, Vimeo)"
            )
        elif any(keyword in lowered for keyword in ("network", "connection")):
            error_guidance = (
                "**Network issues:**\n"
                "- Check your internet connection\n"
                "- Try again in a few minutes\n"
                "- Ensure the URL is accessible from your location"
            )
        else:
            error_guidance = (
                "**General troubleshooting:**\n"
                "- Verify the URL is correct and accessible\n"
                "- Try a different video URL\n"
                "- Ensure the video contains clear speech"
            )

        return (
            f"## ❌ Processing Error\n\n"
            f"**Error:** {error_message}\n\n"
            f"{error_guidance}\n\n"
            f"**Processed URL:** {video_url}"
        )
    finally:
        cleanup_temp_files()  # Always clean up temporary files

def format_results(video_url, accent, confidence, warning, top_predictions):
    """Build the Markdown report shown to the user.

    Args:
        video_url: URL that was analysed (echoed in the report).
        accent: Title-cased accent label of the best prediction.
        confidence: Confidence of the best prediction, as a percentage.
        warning: Reliability note to show under "Analysis Notes".
        top_predictions: Sequence of ``(accent_label, confidence)`` pairs,
            best first. The ranking section is shown only when there is more
            than one entry, and lists at most five.

    Returns:
        The complete Markdown report as a single string.
    """
    accent_mapping = {
        'Us': 'American English (US)',
        'United States': 'American English (General)',
        'England': 'British English (England)',
        'Scotland': 'Scottish English',
        'Wales': 'Welsh English',
        'Ireland': 'Irish English',
        'Australia': 'Australian English',
        'New Zealand': 'New Zealand English',
        'Canada': 'Canadian English',
        'South Africa': 'South African English',
        'India': 'Indian English',
        'Philippines': 'Philippine English',
        'Singapore': 'Singaporean English',
        'Malaysia': 'Malaysian English',
        'Bermuda': 'Bermudian English',
        'Jamaica': 'Jamaican English',
        'Nigeria': 'Nigerian English',
        # NOTE(review): aliases for label spellings the accent model is
        # believed to emit after title-casing -- verify against the model's
        # label encoder. Unmatched keys are harmless.
        'Indian': 'Indian English',
        'African': 'Southern African English',
        'Newzealand': 'New Zealand English',
        'Hongkong': 'Hong Kong English',
        'Southatlandtic': 'South Atlantic English',
    }

    # Unknown labels are shown as-is.
    display_accent = accent_mapping.get(accent, accent)

    sections = [
        "## 🗣️ Accent Classification Results\n\n",
        f"**Detected Accent:** {display_accent}\n\n",
        f"**Confidence Score:** {confidence:.1f}%\n\n",
        f"**Analysis Notes:** {warning}\n\n",
    ]

    if top_predictions and len(top_predictions) > 1:
        sections.append("### 📊 Top Predictions:\n")
        for rank, (pred_accent, pred_conf) in enumerate(top_predictions[:5], 1):
            mapped_accent = accent_mapping.get(pred_accent, pred_accent)
            sections.append(f"{rank}. **{mapped_accent}**: {pred_conf:.1f}%\n")
        sections.append("\n")

    sections.append(
        "### 📋 About This Analysis\n\n"
        "**Model Information:**\n"
        "- Uses advanced neural networks trained on diverse English accents\n"
        "- Analyzes the first 60 seconds of clear speech\n"
        "- Confidence scores indicate prediction reliability\n\n"
        "**Accuracy Notes:**\n"
        "- Higher confidence (>65%) generally indicates more reliable results\n"
        "- Background noise, music, or poor audio quality can affect accuracy\n"
        "- Some accent variations may not be perfectly distinguished\n"
        "- Results are intended for general guidance, not definitive assessment\n\n"
    )

    # The stamp is labelled "UTC", so format the UTC clock explicitly;
    # time.strftime() without a time tuple formats *local* time.
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime())

    sections.append(f"**Processed Video:** {video_url}\n")
    sections.append(f"**Processing Time:** {timestamp}")

    return "".join(sections)

# ===============================================================================
# CELL 8: CLEANUP AND UTILITY FUNCTIONS
# ===============================================================================

def cleanup_temp_files():
    """Delete every regular file in ``TEMP_AUDIO_DIR`` (best effort).

    Sub-directories are left alone. Failures are reported as warnings and
    never raised, because this runs from a ``finally`` block. Each file is
    handled independently so one undeletable file does not stop the rest
    from being removed.
    """
    try:
        filenames = os.listdir(TEMP_AUDIO_DIR)
    except OSError as e:
        print(f"⚠️  Cleanup warning: {e}")
        return

    for filename in filenames:
        file_path = os.path.join(TEMP_AUDIO_DIR, filename)
        if not os.path.isfile(file_path):
            continue
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"⚠️  Cleanup warning: {e}")
        else:
            print(f"🧹 Cleaned up: {filename}")

def test_system():
    """Smoke-test the main components and report the outcome.

    Checks, in order: the accent classifier can be loaded, the audio
    processing libraries import, and yt-dlp imports. Stops at the first
    failure.

    Returns:
        True if every check passed, otherwise False.
    """
    print("🧪 Testing system components...")

    def _probe_audio_libraries():
        import librosa
        import noisereduce

    def _probe_downloader():
        import yt_dlp

    # (probe, message on success, message prefix on failure), run in order.
    probes = (
        (load_accent_classifier, "✅ Accent classifier: OK", "❌ Accent classifier: "),
        (_probe_audio_libraries, "✅ Audio processing libraries: OK", "❌ Audio libraries: "),
        (_probe_downloader, "✅ yt-dlp: OK", "❌ yt-dlp: "),
    )

    for probe, success_message, failure_prefix in probes:
        try:
            probe()
            print(success_message)
        except Exception as e:
            print(f"{failure_prefix}{e}")
            return False

    print("✅ System test passed!")
    return True

# ===============================================================================
# CELL 9: GRADIO INTERFACE
# ===============================================================================

def create_gradio_interface():
    """Build the Gradio UI: one URL text box in, one Markdown report out.

    The interface calls ``process_video_url`` for each submission. Example
    caching and flagging are both turned off.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    description = """
    ## 🎯 English Accent Identification Tool

    This tool analyzes video content to identify the speaker's English accent. Simply paste a video URL below and click "Analyze Accent".

    **Supported Platforms:**
    - 🔴 YouTube (youtube.com, youtu.be)
    - 🟣 Loom (loom.com)
    - 🔵 Vimeo (vimeo.com)
    - 📁 Direct video files (.mp4, .mov, .avi)
    - 📊 Google Drive/Dropbox shared videos

    **Features:**
    - Processes first 60 seconds of audio
    - Advanced noise reduction and audio enhancement
    - Confidence scoring for reliability assessment
    - Multiple accent prediction rankings

    **Tips for Best Results:**
    - Use videos with clear, uninterrupted speech
    - Avoid videos with heavy background music
    - Ensure videos are publicly accessible
    - Longer speech samples generally yield better results
    """

    # Clickable sample inputs shown with the form.
    sample_urls = [
        ["https://www.youtube.com/watch?v=77ZRgyN9WsY"],  # Zoo (Australian English)
        ["https://www.youtube.com/watch?v=MU5L9rIOaqw"],  # Youtube (British English)
    ]

    url_box = gr.Textbox(
        label="🔗 Video URL",
        placeholder="Paste your video URL here (YouTube, Loom, Vimeo, etc.)",
        lines=1,
        max_lines=3,
    )
    report_panel = gr.Markdown(
        label="📊 Analysis Results",
        show_label=True,
    )

    return gr.Interface(
        fn=process_video_url,
        inputs=url_box,
        outputs=report_panel,
        title="🗣️ English Accent Classifier",
        description=description,
        examples=sample_urls,
        cache_examples=False,
        allow_flagging="never",
        theme=gr.themes.Soft(),
    )

# ===============================================================================
# CELL 10: MAIN EXECUTION AND LAUNCH
# ===============================================================================

# Entry point: pre-load the model, run the smoke test, clear stale temporary
# files, then launch the Gradio app (public share first, local-only fallback).
if __name__ == "__main__":
    print("\n" + "="*70)
    print("🚀 INITIALIZING ACCENT CLASSIFICATION SYSTEM")
    print("="*70)

    print("📦 Pre-loading models (this may take a few minutes)...")
    try:
        load_accent_classifier()
        print("✅ Models loaded successfully!")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        print("⚠️  The application cannot proceed without models.")
        raise

    if not test_system():
        print("❌ System test failed. Please check the installation.")
        # A bare `raise` here has no active exception and would surface as the
        # confusing "No active exception to re-raise"; raise an explicit error
        # of the same type instead.
        raise RuntimeError("System test failed. Please check the installation.")

    cleanup_temp_files()

    print("\n🌐 Launching Gradio interface...")

    # Options shared by both launch attempts; only `share` differs.
    launch_options = {
        "debug": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "quiet": False,
    }

    try:
        demo = create_gradio_interface()
        try:
            demo.launch(share=True, **launch_options)
        except Exception as e:
            print(f"⚠️ Public sharing failed: {e}")
            print("Falling back to local interface...")
            demo.launch(share=False, **launch_options)

    except Exception as e:
        print(f"❌ Failed to launch interface: {e}")
        print("Try running: !pip install --upgrade gradio")
        raise

# NOTE(review): launch(debug=True) is understood to block until the server
# stops, so when run as a script this banner may only appear after shutdown
# -- confirm the intended ordering.
print("\n" + "="*70)
print("📋 SETUP COMPLETE - READY TO USE!")
print("="*70)
print("📝 Instructions:")
print("1. Run all cells in order (1-10)")
print("2. Wait for 'SETUP COMPLETE' message")
print("3. Click the public URL link (or local URL if sharing failed) to access the interface")
print("4. Paste a video URL and click 'Analyze Accent'")
print("\n💡 Tip: Keep this Colab session running while using the interface")
print("="*70)