# From Dev to Production
## Call Center Quality Analysis with Acoustic Features
## This notebook demonstrates the complete ML lifecycle in Snowflake:
1. 🤗 Deploy HuggingFace model for feature generation
2. 🔨 Build end-to-end ML model
3. 📊 Track experiments and metrics
4. 🚀 Deploy model for inference
5. ⚡ Enable online feature store

In [None]:
!pip install torch transformers librosa soundfile torchaudio praat-parselmouth pyannote.audio speechbrain openai-whisper resampy imbalanced-learn

In [None]:
# Import required libraries
import snowflake.snowpark as snowpark
from snowflake.snowpark.functions import col, lit
from snowflake.ml.registry import Registry
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
from snowflake.ml.modeling.xgboost import XGBClassifier
import pandas as pd
import numpy as np
import whisper
from transformers import pipeline
import librosa
import warnings
import resampy
from datetime import datetime
import soundfile as sf
from scipy import signal

# Snowflake Notebooks expose the notebook's Snowpark session through
# get_active_session(); it is not a builtin and must be imported explicitly.
from snowflake.snowpark.context import get_active_session

session = get_active_session()

In [None]:
# DIRECTORY(@stage) only works when the stage's directory table is enabled.
# Enabling is idempotent, so do it up front; treat it as best effort because
# the current role may lack ALTER privileges on the stage.
try:
    session.sql("""
        ALTER STAGE BUILD25_DEV_TO_PRODUCTION.DATA.audio_files 
        SET DIRECTORY = (ENABLE = TRUE)
    """).collect()
except Exception as e:
    print(f"Note: could not enable directory on stage ({e})")

# Create a list of audio files from the stage
session.sql("""
    CREATE OR REPLACE TABLE BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list AS 
    SELECT 
        RELATIVE_PATH AS file_name,
        '@BUILD25_DEV_TO_PRODUCTION.DATA.audio_files/' || RELATIVE_PATH AS file_path,
        SIZE AS file_size_bytes,
        LAST_MODIFIED
    FROM DIRECTORY(@BUILD25_DEV_TO_PRODUCTION.DATA.audio_files)
    WHERE RELATIVE_PATH LIKE '%.wav' 
       OR RELATIVE_PATH LIKE '%.mp3'
       OR RELATIVE_PATH LIKE '%.flac'
""").collect()

file_count = session.sql("""
    SELECT COUNT(*) as file_count 
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list
""").collect()[0]['FILE_COUNT']

print(f"‚úì Found {file_count} audio files in stage")

# Show sample files
print("\nSample audio files:")
session.table("BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list").limit(5).show()

## Step 1: Deploy HuggingFace Model
 
**✨ NEW: HuggingFace Model Deployment UI**
 
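 We'll deploy a HuggingFace sentiment model and combine it with Whisper transcription, a speech-emotion model, BERT named-entity recognition, and librosa signal processing in a feature extractor that analyzes call audio to extract:
 - Speech patterns (rate, pauses, rhythm)
 - Voice characteristics (pitch, tone, energy)
 - Interaction dynamics (interruptions, turn-taking)
 - Emotional indicators (stress, emotion trajectory)
 - Transcript content (named entities, sentiment)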
 
### In the HuggingFace Deployment UI:
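1. Browse Model Hub → Select `tabularisai/multilingual-sentiment-analysis`
2. Configure input/output
3. Click 'Deploy'

Name the model `MULTILINGUAL_SENTIMENT_ANALYSIS`: the extraction code below loads it by that name from the Snowflake Model Registry, and falls back to downloading it directly from HuggingFace if it cannot be loaded.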

In [None]:
"""
Snowflake Audio Feature Extraction Script
Processes all audio files from a Snowflake stage and saves features to a table.
Uses Whisper for transcription and Snowflake Model Registry for sentiment analysis.
"""

import snowflake.snowpark as snowpark
from snowflake.ml.registry import Registry
from snowflake.snowpark.functions import col
import pandas as pd
import numpy as np
import whisper
from transformers import pipeline
import librosa
import warnings
import os
import tempfile
import soundfile as sf
from scipy import signal
 
import logging

# The audio/NLP libraries are very chatty: ignore all Python warnings and only
# let librosa's logger report errors, so per-file progress output stays readable.
warnings.filterwarnings('ignore')
logging.getLogger('librosa').setLevel(logging.ERROR)


# ============================================================================
# ACOUSTIC FEATURE EXTRACTOR (Using Whisper + Snowflake Model Registry)
# ============================================================================

class HuggingFaceAcousticExtractor:
    """Extract per-call acoustic, emotion, entity and sentiment features from audio.

    Combines Whisper (transcription), a wav2vec2 speech-emotion classifier, a
    BERT NER model and librosa signal features. Transcript sentiment comes from
    the Snowflake Model Registry model ``MULTILINGUAL_SENTIMENT_ANALYSIS`` when
    a session is available, otherwise from the HuggingFace pipeline
    ``tabularisai/multilingual-sentiment-analysis``.
    """

    # Sample rate every decoded file is resampled to before feature extraction;
    # Whisper needs 16 kHz input when it is handed a raw waveform.
    TARGET_SR = 16000

    # Emotion-classifier labels that are summed into the negative/stress scores.
    NEGATIVE_EMOTIONS = ('angry', 'disgust', 'fearful', 'sad')
    STRESS_EMOTIONS = ('angry', 'fearful')

    def __init__(self, whisper_model_size="base", session=None):
        """Load all models.

        Args:
            whisper_model_size: Checkpoint name passed to ``whisper.load_model``.
            session: Optional Snowpark session. When given, the sentiment model
                is loaded from the Model Registry; when absent, or if the
                registry lookup fails, the HuggingFace pipeline is used instead.
        """
        print("Loading models...")

        # Kept for Model Registry access and for running registry inference.
        self.session = session

        print(f"  Loading Whisper ({whisper_model_size})...")
        self.whisper_model = whisper.load_model(whisper_model_size)

        print("  Loading emotion model...")
        self.emotion_pipeline = pipeline(
            "audio-classification",
            model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
        )

        print("  Loading NER model (BERT)...")
        self.ner_pipeline = pipeline(
            "ner",
            model="dslim/bert-base-NER",
            aggregation_strategy="simple"
        )

        print("  Loading sentiment model from Snowflake Model Registry...")
        self.use_registry_sentiment = False
        if self.session:
            try:
                reg = Registry(session=self.session)
                # Uses the session's current database/schema for the lookup.
                self.sentiment_model = reg.get_model(
                    "MULTILINGUAL_SENTIMENT_ANALYSIS"
                ).default
                print("  ‚úì Sentiment model loaded from Snowflake Model Registry")
                self.use_registry_sentiment = True
            except Exception as e:
                print(f"  Warning: Could not load from registry ({e}), falling back to HuggingFace")
                self.sentiment_pipeline = self._load_hf_sentiment_pipeline()
        else:
            print("  Warning: No session provided, using HuggingFace model")
            self.sentiment_pipeline = self._load_hf_sentiment_pipeline()

        print("‚úì Models loaded successfully")

    @staticmethod
    def _load_hf_sentiment_pipeline():
        """Return the local HuggingFace sentiment pipeline used when the registry is unavailable."""
        return pipeline(
            "sentiment-analysis",
            model="tabularisai/multilingual-sentiment-analysis"
        )

    @staticmethod
    def _run_lengths(mask, value=None):
        """Return the lengths of the consecutive runs in a boolean frame mask.

        Args:
            mask: 1-D boolean sequence (one entry per analysis frame).
            value: If ``True``/``False``, only runs of that value are returned;
                if ``None``, every run is returned in order.

        Returns:
            List of run lengths in frames. Leading and trailing runs are
            included, and a run of a single frame has length 1.
        """
        mask = np.asarray(mask, dtype=bool)
        if mask.size == 0:
            return []
        # A run boundary is wherever a frame differs from its predecessor.
        boundaries = np.flatnonzero(mask[1:] != mask[:-1]) + 1
        starts = np.concatenate(([0], boundaries))
        ends = np.concatenate((boundaries, [mask.size]))
        return [
            int(end - start)
            for start, end in zip(starts, ends)
            if value is None or mask[start] == value
        ]

    def analyze_sentiment_with_registry(self, text: str):
        """Score ``text`` with the Model Registry sentiment model.

        Returns:
            ``(label, score)``: the lowercased label and its score. Either value
            falls back to ``'neutral'`` / ``0.5`` when the model output has no
            matching column, and both do when inference raises.
        """
        try:
            # Registry models run on Snowpark DataFrames; cap the input length.
            input_sp = self.session.create_dataframe(pd.DataFrame({'text': [text[:512]]}))
            result_pd = self.sentiment_model.run(input_sp, function_name="predict").to_pandas()

            # Output column casing depends on how the model was logged.
            label_col = next((c for c in ('label', 'LABEL') if c in result_pd.columns), None)
            score_col = next((c for c in ('score', 'SCORE') if c in result_pd.columns), None)

            if label_col is None:
                # Print available columns to debug
                print(f"      Available columns: {result_pd.columns.tolist()}")
                sentiment_label = 'neutral'
            else:
                sentiment_label = result_pd[label_col].iloc[0].lower()

            sentiment_score = float(result_pd[score_col].iloc[0]) if score_col else 0.5
            return sentiment_label, sentiment_score

        except Exception as e:
            print(f"      Registry sentiment analysis failed: {e}")
            import traceback
            traceback.print_exc()
            return 'neutral', 0.5

    def _load_audio(self, audio_path):
        """Decode ``audio_path`` to mono audio at ``TARGET_SR`` using soundfile + scipy.

        Returns:
            ``(audio, sr)``, or ``None`` (after printing the error) when the
            file cannot be decoded.
        """
        try:
            audio, sr = sf.read(audio_path)
            # Average the channels of multi-channel recordings.
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)
            if sr != self.TARGET_SR:
                print(f"      Resampling from {sr}Hz to {self.TARGET_SR}Hz...")
                num_samples = int(len(audio) * self.TARGET_SR / sr)
                audio = signal.resample(audio, num_samples)
                sr = self.TARGET_SR
        except Exception as audio_error:
            print(f"      Error: Could not load audio file: {audio_error}")
            return None
        return audio, sr

    def _transcribe(self, audio):
        """Transcribe mono ``TARGET_SR`` audio with Whisper.

        Returns:
            ``(transcript_text, word_count)``; ``("", 0)`` if Whisper fails.
        """
        try:
            print("      Transcribing with Whisper...")
            # Hand Whisper the already-decoded waveform instead of the path so
            # the file is not decoded a second time (via ffmpeg). Whisper needs
            # float32 samples for its fp32 model.
            result = self.whisper_model.transcribe(audio.astype(np.float32), fp16=False)
            transcript_text = result['text']
            word_count = len(transcript_text.split())
            print(f"      Transcribed: {word_count} words")
            return transcript_text, word_count
        except Exception as e:
            print(f"      Transcription failed: {e}")
            return "", 0

    def _detect_emotions(self, audio, sr):
        """Classify speech emotion with the wav2vec2 pipeline.

        Returns:
            ``(negative_score, stress_score, dominant_emotion)`` where the scores
            sum the probabilities of ``NEGATIVE_EMOTIONS`` / ``STRESS_EMOTIONS``
            among the labels the pipeline returns; ``(0.0, 0.0, 'unknown')`` on
            failure.
        """
        try:
            print("      Analyzing emotions...")
            # Raw-waveform input avoids a second (ffmpeg) decode of the file.
            emotions = self.emotion_pipeline(
                {"raw": audio.astype(np.float32), "sampling_rate": sr}
            )
            negative_score = sum(
                e['score'] for e in emotions if e['label'] in self.NEGATIVE_EMOTIONS
            )
            stress_score = sum(
                e['score'] for e in emotions if e['label'] in self.STRESS_EMOTIONS
            )
            dominant_emotion = max(emotions, key=lambda e: e['score'])['label']
            print(f"      Emotion: {dominant_emotion} (stress: {stress_score:.3f})")
            return negative_score, stress_score, dominant_emotion
        except Exception as e:
            print(f"      Emotion detection failed: {e}")
            return 0.0, 0.0, 'unknown'

    def _extract_entities(self, text):
        """Run BERT NER over ``text``.

        Returns:
            ``(entities, entity_count)`` where ``entities`` maps the groups
            ``PER``/``ORG``/``LOC``/``MISC`` to lists of entity strings and
            ``entity_count`` is the total number of entities the model returned.
        """
        entities = {'PER': [], 'ORG': [], 'LOC': [], 'MISC': []}
        try:
            print("      Extracting named entities...")
            ner_results = self.ner_pipeline(text)
            for entity in ner_results:
                group = entities.get(entity['entity_group'])
                if group is not None:
                    group.append(entity['word'])
            entity_count = len(ner_results)
            print(f"      Found {entity_count} entities: {len(entities['PER'])} persons, {len(entities['ORG'])} orgs, {len(entities['LOC'])} locations")
            return entities, entity_count
        except Exception as e:
            print(f"      NER failed: {e}")
            return entities, 0

    def _analyze_transcript_sentiment(self, text):
        """Score transcript sentiment via the registry model or the HuggingFace pipeline.

        Returns:
            ``(label, score)``; ``('neutral', 0.5)`` if analysis raises.
        """
        try:
            print("      Analyzing transcript sentiment...")
            if self.use_registry_sentiment:
                label, score = self.analyze_sentiment_with_registry(text)
                print(f"      Sentiment (Registry): {label} (confidence: {score:.3f})")
            else:
                result = self.sentiment_pipeline(text[:512])
                label = result[0]['label'].lower()
                score = result[0]['score']
                print(f"      Sentiment (HuggingFace): {label} (confidence: {score:.3f})")
            return label, score
        except Exception as e:
            print(f"      Sentiment analysis failed: {e}")
            return 'neutral', 0.5

    def extract_all_features(self, audio_path: str) -> dict:
        """Extract all acoustic, emotion, entity and sentiment features from one file.

        Args:
            audio_path: Local path to an audio file readable by soundfile.

        Returns:
            A flat dict of feature name -> value, or ``None`` if the file is
            missing, empty, cannot be decoded, or a feature computation raises.
        """
        try:
            # Verify the file exists and is non-empty before decoding it.
            if not os.path.exists(audio_path):
                print(f"      Error: File does not exist: {audio_path}")
                return None

            file_size = os.path.getsize(audio_path)
            if file_size == 0:
                print("      Error: File is empty (0 bytes)")
                return None

            print(f"      Loading audio file ({file_size} bytes)...")
            loaded = self._load_audio(audio_path)
            if loaded is None:
                return None
            audio, sr = loaded

            duration_sec = len(audio) / sr
            duration_minutes = duration_sec / 60
            print(f"      Audio loaded: {duration_sec:.2f}s, {sr}Hz")

            # The transcript's word count feeds the speaking-rate feature.
            transcript_text, word_count = self._transcribe(audio)

            # Frame-level RMS energy drives the speech/silence segmentation.
            frame_length = 2048
            hop_length = 512
            frame_duration = hop_length / sr
            rms = librosa.feature.rms(y=audio, frame_length=frame_length, hop_length=hop_length)[0]
            # Heuristic: frames louder than 30% of the mean energy count as speech.
            speech_frames = rms > np.mean(rms) * 0.3

            # Speech rate features
            speaking_rate_wpm = float(word_count / duration_minutes if duration_minutes > 0 else 0)
            speech_energy = rms[speech_frames]
            speech_rate_variability = float(
                np.std(speech_energy) / np.mean(speech_energy)
                if len(speech_energy) > 0 else 0
            )

            # Pause features: every speech -> silence transition starts a pause.
            pause_count = int(np.sum(np.diff(speech_frames.astype(int)) == -1))
            pause_frequency = float(pause_count / duration_minutes if duration_minutes > 0 else 0)
            # Lengths (in frames) of every silent run, including a trailing one.
            non_speech_segments = self._run_lengths(speech_frames, False)
            avg_pause_duration = float(
                np.mean(non_speech_segments) * frame_duration
                if non_speech_segments else 0
            )

            # Pitch features: the strongest piptrack bin per frame; a pitch of 0
            # means no pitch was found in that frame.
            pitches, magnitudes = librosa.piptrack(y=audio, sr=sr, fmin=75, fmax=300)
            strongest = pitches[magnitudes.argmax(axis=0), np.arange(pitches.shape[1])]
            pitch_values = strongest[strongest > 0]
            if len(pitch_values) > 10:
                avg_pitch_hz = float(np.mean(pitch_values))
                pitch_variance = float(np.var(pitch_values))
                pitch_range_hz = float(np.max(pitch_values) - np.min(pitch_values))
            else:
                avg_pitch_hz = 0.0
                pitch_variance = 0.0
                pitch_range_hz = 0.0

            # Energy features
            energy_mean = float(np.mean(rms))
            energy_variance = float(np.var(rms))
            rms_db = librosa.amplitude_to_db(rms)
            dynamic_range_db = float(np.max(rms_db) - np.min(rms_db))

            # Spectral features
            spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)[0]
            spectral_centroid_mean = float(np.mean(spectral_centroid))
            zcr = librosa.feature.zero_crossing_rate(audio)[0]
            zcr_mean = float(np.mean(zcr))

            # Voice-quality proxies: scaled heuristics derived from ZCR and RMS,
            # not standard jitter/shimmer/HNR measurements.
            jitter = float(np.std(np.diff(zcr)) * 0.01)
            shimmer = float(energy_variance * 0.1)
            harmonics_to_noise_ratio = float(
                min(energy_mean / (energy_variance + 1e-10) * 10, 20.0)
            )

            # Silence features
            silence_ratio = float(1.0 - np.mean(speech_frames))
            speech_ratio = float(np.mean(speech_frames))
            speech_to_silence_ratio = float(
                speech_ratio / silence_ratio if silence_ratio > 0.01 else 10.0
            )

            # Interaction features: unusually large frame-to-frame energy jumps
            # are counted as interruptions (scaled and capped at 15).
            energy_diff = np.abs(np.diff(rms))
            threshold_interruption = np.mean(energy_diff) + 1.5 * np.std(energy_diff)
            interruptions = int(np.sum(energy_diff > threshold_interruption))
            interruption_count = int(min(interruptions / 20, 15))

            # Frames above the 60th energy percentile are "high energy"; their
            # share is used as a proxy for the agent's talk ratio.
            high_energy = rms > np.percentile(rms, 60)
            transitions_energy = int(np.sum(np.abs(np.diff(high_energy.astype(int)))))
            turn_taking_rate = float(
                transitions_energy / 2 / duration_minutes if duration_minutes > 0 else 0
            )
            agent_talk_ratio = float(np.mean(high_energy))

            # Turn duration: mean length of every high/low-energy run, including
            # single-frame runs and the final run.
            turn_durations = self._run_lengths(high_energy)
            avg_turn_duration = float(
                np.mean(turn_durations) * frame_duration if turn_durations else 0
            )

            negative_score, stress_score, dominant_emotion = self._detect_emotions(audio, sr)

            # Emotion volatility: spread of energy/spectral variability across
            # six equal windows of the call.
            num_windows = 6
            window_size = len(rms) // num_windows
            emotion_scores = []
            for i in range(num_windows):
                start = i * window_size
                end = min(start + window_size, len(rms))
                if end - start >= 10:
                    window_rms = rms[start:end]
                    window_spectral = spectral_centroid[start:end]
                    energy_var = np.var(window_rms)
                    spectral_var = np.var(window_spectral)
                    score = (energy_var / (np.mean(window_rms) + 1e-10) +
                             spectral_var / (np.mean(window_spectral) + 1e-10)) / 2
                    emotion_scores.append(min(score * 0.1, 1.0))

            emotion_volatility = float(
                np.std(emotion_scores) if len(emotion_scores) > 1 else 0
            )

            # Transcript-derived features (defaults when Whisper produced no text).
            entities = {'PER': [], 'ORG': [], 'LOC': [], 'MISC': []}
            entity_count = 0
            sentiment_label = 'neutral'
            sentiment_score = 0.5
            if transcript_text.strip():
                entities, entity_count = self._extract_entities(transcript_text)
                sentiment_label, sentiment_score = self._analyze_transcript_sentiment(transcript_text)

            print("      ‚úì All features extracted successfully")

            return {
                'speaking_rate_wpm': speaking_rate_wpm,
                'speech_rate_variability': speech_rate_variability,
                'avg_pause_duration_sec': avg_pause_duration,
                'pause_frequency_per_min': pause_frequency,
                'avg_pitch_hz': avg_pitch_hz,
                'pitch_variance': pitch_variance,
                'pitch_range_hz': pitch_range_hz,
                'energy_mean': energy_mean,
                'energy_variance': energy_variance,
                'dynamic_range_db': dynamic_range_db,
                'spectral_centroid': spectral_centroid_mean,
                'harmonics_to_noise_ratio': harmonics_to_noise_ratio,
                'jitter': jitter,
                'shimmer': shimmer,
                'zero_crossing_rate': zcr_mean,
                'silence_ratio': silence_ratio,
                'speech_to_silence_ratio': speech_to_silence_ratio,
                'interruption_count': interruption_count,
                'agent_talk_ratio': agent_talk_ratio,
                'turn_taking_rate': turn_taking_rate,
                'avg_turn_duration_sec': avg_turn_duration,
                'avg_emotion_score': float(negative_score),
                'emotion_volatility': emotion_volatility,
                'stress_indicators': float(stress_score),
                'dominant_emotion': dominant_emotion,
                'transcript': transcript_text,
                'word_count': word_count,
                'entity_count': entity_count,
                'entities_person': ','.join(entities['PER']),
                'entities_org': ','.join(entities['ORG']),
                'entities_loc': ','.join(entities['LOC']),
                'entities_misc': ','.join(entities['MISC']),
                'sentiment_label': sentiment_label,
                'sentiment_score': float(sentiment_score)
            }
        except Exception as e:
            print(f"      ‚úó Error processing audio: {e}")
            import traceback
            traceback.print_exc()
            return None


# ============================================================================
# SNOWFLAKE PROCESSING SCRIPT
# ============================================================================

def parse_call_metadata(file_name, idx):
    """Return ``(call_id, agent_id)`` parsed from a staged audio file name.

    File names are expected to look like ``<call_id>_<agent_id>[_...].<ext>``.
    Only the base name without its extension is parsed, so stage
    sub-directories and dots elsewhere in the name never leak into the IDs.

    Args:
        file_name: Stage-relative path, e.g. ``"2024/C001_A07.wav"``.
        idx: Position of the file in the listing; used for the fallback IDs
            ``call_<idx>`` / ``agent_<idx % 10>`` when a part is missing.
    """
    stem = os.path.splitext(os.path.basename(file_name))[0]
    parts = stem.split('_')
    call_id = parts[0] or f'call_{idx}'
    agent_id = parts[1] if len(parts) > 1 and parts[1] else f'agent_{idx % 10}'
    return call_id, agent_id


def _download_staged_file(session, stage_path, file_name, temp_dir):
    """GET one staged file into ``temp_dir``.

    Returns:
        The local path of the downloaded file, or ``None`` (after printing
        diagnostics) when the file cannot be found locally after the GET.
    """
    print(f"      Downloading from stage: {stage_path}")
    get_result = session.file.get(stage_path, temp_dir)
    print(f"      Download result: {get_result}")

    expected_path = os.path.join(temp_dir, file_name)
    # Look under both the stage-relative path and the bare file name, since
    # the local layout GET produces depends on the stage path.
    for candidate in (expected_path, os.path.join(temp_dir, os.path.basename(file_name))):
        if os.path.exists(candidate):
            return candidate

    print(f"      ‚úó Error: File not found after download")
    print(f"         Expected: {expected_path}")
    print(f"         Directory contents: {os.listdir(temp_dir)}")
    return None


def process_stage_audio_files(session: "snowpark.Session"):
    """
    Process all audio files from Snowflake stage and save features to table.

    Lists .wav/.mp3/.flac files in ``@BUILD25_DEV_TO_PRODUCTION.DATA.audio_files``,
    downloads each one into a temporary directory, extracts features with
    ``HuggingFaceAcousticExtractor`` and overwrites
    ``BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features`` with the results.
    The temporary directory is removed even if processing fails part-way.

    Returns:
        The Snowpark DataFrame that was saved, or ``None`` if no file could be
        processed.
    """
    import shutil

    print("="*80)
    print("SNOWFLAKE AUDIO FEATURE EXTRACTION")
    print("="*80)

    # Step 1: Enable directory on stage (required by DIRECTORY() below)
    print("\n1. Enabling directory on stage...")
    try:
        session.sql("""
            ALTER STAGE BUILD25_DEV_TO_PRODUCTION.DATA.audio_files 
            SET DIRECTORY = (ENABLE = TRUE)
        """).collect()
        print("   ‚úì Directory enabled")
    except Exception as e:
        print(f"   Note: Directory may already be enabled ({e})")

    # Step 2: Create list of audio files
    print("\n2. Getting list of audio files from stage...")
    session.sql("""
        CREATE OR REPLACE TABLE BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list AS 
        SELECT 
            RELATIVE_PATH AS file_name,
            '@BUILD25_DEV_TO_PRODUCTION.DATA.audio_files/' || RELATIVE_PATH AS file_path,
            SIZE AS file_size_bytes,
            LAST_MODIFIED
        FROM DIRECTORY(@BUILD25_DEV_TO_PRODUCTION.DATA.audio_files)
        WHERE RELATIVE_PATH LIKE '%.wav' 
           OR RELATIVE_PATH LIKE '%.mp3'
           OR RELATIVE_PATH LIKE '%.flac'
    """).collect()

    file_count = session.sql("""
        SELECT COUNT(*) as file_count 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list
    """).collect()[0]['FILE_COUNT']

    print(f"   ‚úì Found {file_count} audio files")

    # Show sample
    print("\n   Sample files:")
    session.table("BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list").limit(5).show()

    # Step 3: Get files as pandas DataFrame
    print("\n3. Loading file list...")
    audio_files_df = session.table("BUILD25_DEV_TO_PRODUCTION.DATA.audio_file_list").to_pandas()
    total_files = len(audio_files_df)
    print(f"   ‚úì Loaded {total_files} files")

    # Step 4: Initialize feature extractor (pass session for model registry)
    print("\n4. Initializing Whisper and models...")
    extractor = HuggingFaceAcousticExtractor(whisper_model_size="base", session=session)

    # Step 5: Create temporary directory for downloaded files
    print("\n5. Setting up temporary directory for audio files...")
    temp_dir = tempfile.mkdtemp(prefix="snowflake_audio_")
    print(f"   ‚úì Temporary directory: {temp_dir}")

    features_list = []
    errors = []

    # Step 6: Process all audio files; the finally clause guarantees the temp
    # directory is removed even if the loop is interrupted.
    try:
        print(f"\n6. Processing {total_files} audio files...")
        print("   This may take a while depending on file count...\n")

        for idx, (_, row) in enumerate(audio_files_df.iterrows()):
            stage_path = row['FILE_PATH']
            file_name = row['FILE_NAME']
            call_id, agent_id = parse_call_metadata(file_name, idx)

            print(f"   Processing [{idx+1}/{total_files}]: {file_name}")

            local_path = None
            try:
                local_path = _download_staged_file(session, stage_path, file_name, temp_dir)
                if local_path is None:
                    errors.append({'file_name': file_name, 'error': 'Download failed - file not found'})
                    continue

                file_size_local = os.path.getsize(local_path)
                print(f"      ‚úì File downloaded: {file_size_local:,} bytes")

                features = extractor.extract_all_features(local_path)
                if features:
                    # Add call metadata (insertion order fixes the column order)
                    features.update(
                        call_id=call_id,
                        agent_id=agent_id,
                        file_name=file_name,
                        file_path=stage_path,
                        file_size_bytes=int(row['FILE_SIZE_BYTES']),
                        processed_at=pd.Timestamp.now(),
                    )
                    features_list.append(features)
                    print(f"      ‚úì Features extracted successfully")
                else:
                    errors.append({'file_name': file_name, 'error': 'Feature extraction failed'})

            except Exception as e:
                print(f"      Error: {e}")
                errors.append({'file_name': file_name, 'error': str(e)})
            finally:
                # Free disk space per file; the whole temp dir is removed later.
                if local_path is not None:
                    try:
                        os.remove(local_path)
                    except OSError:
                        pass

            # Progress update every 10 files
            if (idx + 1) % 10 == 0:
                print(f"\n   Progress: {idx+1}/{total_files} files processed\n")

        print(f"\n   ‚úì Successfully processed {len(features_list)} files")
        if errors:
            print(f"   ‚ö† {len(errors)} files failed")
            print("\n   Failed files:")
            for err in errors[:5]:  # Show first 5 errors
                print(f"      - {err['file_name']}: {err['error']}")
    finally:
        print(f"\n7. Cleaning up temporary directory...")
        try:
            shutil.rmtree(temp_dir)
            print(f"   ‚úì Temporary directory removed")
        except OSError as e:
            print(f"   Warning: Could not remove temp directory: {e}")

    # Step 8: Convert to DataFrame
    print("\n8. Creating features DataFrame...")

    if not features_list:
        print("   ‚úó ERROR: No features were extracted successfully!")
        print("   Please check the error messages above.")
        if errors:
            print(f"\n   All {len(errors)} files failed. Sample errors:")
            for err in errors[:10]:
                print(f"      - {err['file_name']}: {err['error']}")
        return None

    features_pd = pd.DataFrame(features_list)
    print(f"   ‚úì Created DataFrame with {len(features_pd)} rows and {len(features_pd.columns)} columns")

    # Show sample
    print("\n   Sample features:")
    print(features_pd[['call_id', 'agent_id', 'speaking_rate_wpm', 'interruption_count', 'stress_indicators']].head())

    # Step 9: Save to Snowflake table (column names keep their lowercase casing)
    print("\n9. Saving features to Snowflake table...")
    features_sp = session.create_dataframe(features_pd)
    features_sp.write.mode('overwrite').save_as_table('BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features')

    print("   ‚úì Features saved to: BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features")

    # Step 10: Verify saved data
    print("\n10. Verifying saved data...")
    saved_count = session.sql("""
        SELECT COUNT(*) as row_count 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
    """).collect()[0]['ROW_COUNT']

    print(f"   ‚úì Verified {saved_count} rows in table")

    # Show sample from table
    print("\n   Sample from saved table:")
    session.table('BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features').limit(5).show()

    # Step 11: Summary statistics
    print("\n11. Summary Statistics:")

    # First, check what columns exist
    print("   Checking column names...")
    columns_df = session.sql("""
        SELECT * 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features 
        LIMIT 1
    """).collect()

    if columns_df:
        print(f"   Available columns: {list(columns_df[0].asDict().keys())[:10]}...")

    # Use double quotes to preserve case sensitivity of the lowercase columns
    try:
        summary = session.sql("""
            SELECT 
                COUNT(*) as total_calls,
                COUNT(DISTINCT "agent_id") as unique_agents,
                ROUND(AVG("speaking_rate_wpm"), 2) as avg_speaking_rate,
                ROUND(AVG("interruption_count"), 2) as avg_interruptions,
                ROUND(AVG("stress_indicators"), 3) as avg_stress,
                ROUND(AVG("agent_talk_ratio"), 3) as avg_agent_talk_ratio
            FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
        """).collect()[0]

        print(f"   Total calls processed: {summary['TOTAL_CALLS']}")
        print(f"   Unique agents: {summary['UNIQUE_AGENTS']}")
        print(f"   Avg speaking rate: {summary['AVG_SPEAKING_RATE']} WPM")
        print(f"   Avg interruptions: {summary['AVG_INTERRUPTIONS']}")
        print(f"   Avg stress level: {summary['AVG_STRESS']}")
        print(f"   Avg agent talk ratio: {summary['AVG_AGENT_TALK_RATIO']}")
    except Exception as e:
        print(f"   Note: Could not compute summary statistics: {e}")
        print("   You can query the table directly to see the data.")

    print("\n" + "="*80)
    print("PROCESSING COMPLETE!")
    print("="*80)
    print(f"\n‚úì Processed {len(features_list)} audio files")
    print(f"‚úì Extracted {len(features_pd.columns)} features per file")
    print(f"‚úì Saved to table: BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features")
    print("\nReady for model training!")

    return features_sp

## Step 2: Build End-to-End Model
 
Now we'll use the feature extractor (Whisper, the emotion and NER models, and the deployed HuggingFace sentiment model) to:
1. Generate acoustic features from call audio
2. Train an ML model to predict call quality

### 2A. Generate Features from Audio

In [None]:
# Show which database/schema/warehouse the notebook session is using
print("\n‚úì Connected to Snowflake")
print(f"  Database: {session.get_current_database()}")
print(f"  Schema: {session.get_current_schema()}")
print(f"  Warehouse: {session.get_current_warehouse()}")

# Process all audio files; returns None when no file could be processed
features_df = process_stage_audio_files(session)

if features_df is None:
    # Don't point users at a features table this run did not write.
    print("\nWARNING: No features were saved - fix the errors above before training.")
else:
    print("\n" + "="*80)
    print("You can now use the features for model training:")
    print("  Table: BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features")
    print("="*80)

In [None]:
-- Spot-check a few transcripts next to their emotion/stress/interruption
-- features. Column names are double-quoted because the features table was
-- written from a pandas DataFrame with lowercase column names, so Snowflake
-- only resolves them case-sensitively.
SELECT 
    "call_id",
    "transcript",
    "dominant_emotion",
    "stress_indicators",
    "interruption_count"
FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
-- Skip calls where Whisper produced no text
WHERE "transcript" IS NOT NULL 
  AND LENGTH("transcript") > 0
LIMIT 5;

In [None]:
-- Label every call that has a transcript with outcome fields inferred by an
-- LLM (AI_COMPLETE): call_resolved, customer_satisfaction_score and
-- resolution_confidence. These labels are joined back to the acoustic
-- features to form the training target.
CREATE OR REPLACE TABLE BUILD25_DEV_TO_PRODUCTION.DATA.call_outcomes AS
WITH analyzed AS (
    -- One AI_COMPLETE call per call: transcript plus acoustic context,
    -- asking for a bare JSON object.
    SELECT 
        "call_id",
        "agent_id",
        AI_COMPLETE(
            'claude-sonnet-4-5',
            CONCAT(
                'You are an expert call center quality analyst. Analyze the following customer service call and respond with ONLY a JSON object.\n\n',
                'TRANSCRIPT:\n', "transcript", '\n\n',
                'ACOUSTIC CONTEXT:\n',
                '- Dominant emotion: ', "dominant_emotion", '\n',
                '- Stress level: ', "stress_indicators", '\n',
                '- Interruptions: ', "interruption_count", '\n\n',
                'Based on the transcript and context, return ONLY this JSON (no markdown, no other text):\n',
                '{"call_resolved": 0 or 1, "customer_satisfaction_score": 1-5, "resolution_confidence": 0.0-1.0, "reasoning": "explanation"}\n\n',
                'SCORING:\n',
                '- call_resolved: 1 if issue fully resolved, 0 otherwise\n',
                '- customer_satisfaction_score: 5=very satisfied, 4=satisfied, 3=neutral, 2=dissatisfied, 1=very dissatisfied\n',
                '- resolution_confidence: 0.0 to 1.0\n',
                '- reasoning: brief explanation\n\n',
                'Return ONLY the JSON object.'
            )
        ) as claude_response
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
    WHERE "transcript" IS NOT NULL 
      AND LENGTH("transcript") > 0
),
cleaned AS (
    -- Strip markdown code fences in case the model wraps the JSON anyway
    SELECT 
        "call_id",
        "agent_id",
        claude_response,
        TRIM(REPLACE(REPLACE(claude_response, '```json', ''), '```', '')) as cleaned_response
    FROM analyzed
),
parsed AS (
    -- TRY_PARSE_JSON returns NULL instead of failing on malformed output
    SELECT 
        "call_id",
        "agent_id",
        claude_response,
        cleaned_response,
        TRY_PARSE_JSON(cleaned_response) as parsed_json
    FROM cleaned
)
-- Missing JSON fields fall back to defaults (not resolved, neutral
-- satisfaction 3, confidence 0.5). Calls whose response did not parse at all
-- are dropped, so they receive no label.
-- NOTE(review): the defaulted labels are indistinguishable from real ones;
-- consider keeping a flag if parse quality matters for training.
SELECT 
    "call_id",
    "agent_id",
    COALESCE(parsed_json:call_resolved::INTEGER, 0) as call_resolved,
    COALESCE(parsed_json:customer_satisfaction_score::INTEGER, 3) as customer_satisfaction_score,
    COALESCE(parsed_json:resolution_confidence::FLOAT, 0.5) as resolution_confidence,
    parsed_json:reasoning::STRING as reasoning,
    claude_response as raw_response,
    CURRENT_TIMESTAMP() as analyzed_at
FROM parsed
WHERE parsed_json IS NOT NULL;


In [None]:
-- Summary of the LLM-derived outcome labels. A "high quality" call is one that
-- was resolved AND rated >= 4 for satisfaction, the same definition used for
-- the training target.
SELECT 
    COUNT(*) as total_calls,
    SUM(call_resolved) as resolved_calls,
    ROUND(AVG(call_resolved) * 100, 1) as resolution_rate_pct,
    ROUND(AVG(customer_satisfaction_score), 2) as avg_satisfaction,
    ROUND(AVG(resolution_confidence), 3) as avg_confidence,
    COUNT(CASE WHEN call_resolved = 1 AND customer_satisfaction_score >= 4 THEN 1 END) as high_quality_calls,
    -- NULLIF turns the divisor into NULL for an empty table instead of raising
    -- a division-by-zero error (the percentage is then NULL).
    ROUND(COUNT(CASE WHEN call_resolved = 1 AND customer_satisfaction_score >= 4 THEN 1 END) * 100.0 / NULLIF(COUNT(*), 0), 1) as high_quality_pct
FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_outcomes;

In [None]:
"""
Snowflake ML Model Training for Call Quality Prediction
Optimized for F1 Score with Threshold Tuning
"""

import snowflake.snowpark as snowpark
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
from snowflake.snowpark.functions import col, when, length, regexp_count, random
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve
import numpy as np
import pandas as pd

# ============================================================================
# DATA PREPARATION
# ============================================================================

# Section banner: rule / title / rule.
rule_line = "=" * 80
for banner_line in (rule_line, "CALL QUALITY PREDICTION - MODEL TRAINING", rule_line):
    print(banner_line)

print("\n1. Creating training dataset...")

# Join the acoustic features with the LLM-derived outcomes and derive the
# binary target: 1 when the call was resolved AND satisfaction is 4 or 5.
training_query = """
    SELECT 
        f.*,
        o.call_resolved,
        o.customer_satisfaction_score,
        CASE 
            WHEN o.call_resolved = 1 
                 AND o.customer_satisfaction_score >= 4
            THEN 1 
            ELSE 0 
        END as "high_quality_call"
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features f
    JOIN BUILD25_DEV_TO_PRODUCTION.DATA.call_outcomes o 
        ON f."call_id" = o."call_id"
"""
training_df = session.sql(training_query)

n_training_calls = training_df.count()
print(f"   ‚úì Training dataset size: {n_training_calls} calls")

# ============================================================================
# FEATURE DEFINITION
# ============================================================================

print("\n2. Defining feature sets...")

# Column names are stored as lower-case quoted identifiers in Snowflake, so
# each bare name is wrapped in double quotes before being used as a column.
acoustic_features = [f'"{name}"' for name in (
    "speaking_rate_wpm",
    "speech_rate_variability",
    "avg_pause_duration_sec",
    "pause_frequency_per_min",
    "avg_pitch_hz",
    "pitch_variance",
    "pitch_range_hz",
    "energy_mean",
    "energy_variance",
    "dynamic_range_db",
    "spectral_centroid",
    "harmonics_to_noise_ratio",
    "jitter",
    "shimmer",
    "zero_crossing_rate",
    "silence_ratio",
    "speech_to_silence_ratio",
)]

interaction_features = [f'"{name}"' for name in (
    "interruption_count",
    "agent_talk_ratio",
    "turn_taking_rate",
    "avg_turn_duration_sec",
)]

emotion_features = [f'"{name}"' for name in (
    "avg_emotion_score",
    "emotion_volatility",
    "stress_indicators",
)]

sentiment_features = [f'"{name}"' for name in ("sentiment_score",)]

ner_features = [f'"{name}"' for name in ("word_count", "entity_count")]

# Display label -> feature list; insertion order fixes the order of
# all_features and of the per-group report below.
feature_groups = {
    "Acoustic": acoustic_features,
    "Interaction": interaction_features,
    "Emotion": emotion_features,
    "Sentiment": sentiment_features,
    "NER/Content": ner_features,
}
all_features = [feature for group in feature_groups.values() for feature in group]

print(f"   ‚úì Total features: {len(all_features)}")
for group_label, group_columns in feature_groups.items():
    print(f"      - {group_label}: {len(group_columns)}")

# ============================================================================
# DATA SPLITTING
# ============================================================================

print("\n3. Splitting data into train/test/holdout sets...")

from snowflake.snowpark.functions import uniform

# Attach one seeded draw in [0, 1] per call and materialize it:
#  * Snowflake's random() returns a signed 64-bit integer, not a fraction, so
#    comparing it with 0.05 / 0.8 splits roughly 50/50 and silently discards
#    about half of the data; uniform(0.0, 1.0, random(seed)) maps it onto [0, 1].
#  * Snowpark DataFrames are lazy: without cache_result() the random column is
#    re-drawn by every action (count(), to_pandas(), fit(), ...), so train,
#    test and holdout would be re-sampled each time and could overlap.
training_df = training_df.with_column(
    "random_split", uniform(0.0, 1.0, random(42))
).cache_result()

# Holdout: the 5 calls with the smallest draw (not used in training or testing).
holdout_df = training_df.sort(col("random_split")).limit(5).cache_result()
holdout_draws = [row[0] for row in holdout_df.select("random_split").collect()]
# Every call above the largest holdout draw stays available for train/test,
# so no rows are dropped (filtering <= 0.05 then LIMIT 5 discarded the rest).
holdout_cutoff = max(holdout_draws) if holdout_draws else -1.0
remaining_df = training_df.filter(col("random_split") > holdout_cutoff)

# ~80/20 train/test split of the remaining calls, reusing the same cached draw.
train_df = remaining_df.filter(col("random_split") <= 0.8).drop("random_split")
test_df = remaining_df.filter(col("random_split") > 0.8).drop("random_split")
holdout_df = holdout_df.drop("random_split")

print(f"   ‚úì Train set: {train_df.count()} calls")
print(f"   ‚úì Test set: {test_df.count()} calls")
print(f"   ‚úì Holdout set: {holdout_df.count()} calls (for future validation)")

# Save holdout call IDs for reference
holdout_ids = holdout_df.select('"call_id"').to_pandas()['call_id'].tolist()
print(f"\n   üìã Holdout Call IDs (saved for later):")
for i, call_id in enumerate(holdout_ids, 1):
    print(f"      {i}. {call_id}")

train_positive = train_df.filter(col('"high_quality_call"') == 1).count()
train_total = train_df.count()
train_negative = train_total - train_positive

print(f"\n   üìä Class Distribution:")
print(f"      Train positive: {train_positive} ({train_positive/train_total:.2%})")
print(f"      Train negative: {train_negative} ({train_negative/train_total:.2%})")

# Ratio of negatives to positives, used by XGBoost to up-weight the minority class.
if train_positive > 0:
    scale_pos_weight = train_negative / train_positive
    print(f"      Scale pos weight: {scale_pos_weight:.2f}")
else:
    scale_pos_weight = 1.0
    print(f"      ‚ö†Ô∏è  WARNING: No positive examples in training data!")

# ============================================================================
# EXPERIMENT TRACKING SETUP
# ============================================================================

# Every run started below is grouped under this single experiment.
print("\n4. Setting up experiment tracking...")
experiment_name = 'call_quality_prediction_with_ner'
exp = ExperimentTracking(session=session)
exp.set_experiment(experiment_name)
print(f"   ‚úì Experiment: {experiment_name}")

# ============================================================================
# HELPER FUNCTION: FIND OPTIMAL THRESHOLD
# ============================================================================

def find_optimal_threshold(y_true, y_pred_proba, metric='f1'):
    """Pick a decision threshold for positive-class probabilities.

    Args:
        y_true: Binary ground-truth labels (0/1), aligned with ``y_pred_proba``.
        y_pred_proba: Predicted probability of the positive class per sample.
        metric: ``'f1'`` picks the threshold with the highest F1 score;
            ``'recall'`` picks the threshold with the highest precision among
            those whose recall exceeds 0.7.

    Returns:
        Tuple ``(threshold, score, metric_name)``. For ``'f1'`` the score is
        the F1 at that threshold (threshold falls back to 0.5 if the best
        point has no threshold). For ``'recall'`` the score is the recall at
        that threshold, or ``(0.3, 0.0)`` when no threshold reaches recall > 0.7.

    Raises:
        ValueError: If ``metric`` is neither ``'f1'`` nor ``'recall'``.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)

    if metric == 'f1':
        # F1 at every point of the curve; the epsilon avoids 0/0.
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
        optimal_idx = np.argmax(f1_scores)
        # precision/recall have one more entry than thresholds; that last
        # point has no threshold, hence the fallback.
        optimal_threshold = thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5
        optimal_score = f1_scores[optimal_idx]
        metric_name = "F1"
    elif metric == 'recall':
        # Drop the final threshold-less point so every array index i refers to
        # thresholds[i]. Indexing thresholds with a mask built from the
        # full-length recall array raises IndexError (length mismatch).
        curve_precision = precision[:-1]
        curve_recall = recall[:-1]
        valid_indices = curve_recall > 0.7
        if valid_indices.any():
            # Among thresholds with recall > 0.7, take the most precise one.
            candidates = np.flatnonzero(valid_indices)
            best = candidates[np.argmax(curve_precision[candidates])]
            optimal_threshold = thresholds[best]
            optimal_score = curve_recall[best]
        else:
            optimal_threshold = 0.3
            optimal_score = 0.0
        metric_name = "Recall"
    else:
        raise ValueError(f"Unsupported metric {metric!r}; expected 'f1' or 'recall'")

    return optimal_threshold, optimal_score, metric_name

# ============================================================================
# EXPERIMENT 1: BASELINE WITH OPTIMIZED HYPERPARAMETERS
# ============================================================================

print("\n" + "="*80)
print("EXPERIMENT 1: Optimized XGBoost with All Features")
print("="*80)

run_name_1 = f"optimized_xgb_all_features_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

with exp.start_run(run_name_1) as run:
    print(f"\n‚úì Started run: {run_name_1}")
    print("Training model...")

    # Single source for the hyperparameters passed to the model AND logged to
    # the run, so the two cannot drift apart.
    xgb_params = {
        'n_estimators': 300,
        'max_depth': 4,
        'learning_rate': 0.03,
        'min_child_weight': 1,
        'subsample': 0.7,
        'colsample_bytree': 0.7,
        'gamma': 0.1,
        'reg_alpha': 0.1,
        'reg_lambda': 1.0,
        'scale_pos_weight': scale_pos_weight * 2,
        'random_state': 42,
    }

    model_1 = XGBClassifier(
        **xgb_params,
        input_cols=all_features,
        label_cols=['"high_quality_call"'],
        output_cols=['PREDICTION']
    )

    model_1.fit(train_df)
    print("   ‚úì Model trained")

    predictions_hard = model_1.predict(test_df)
    predictions_proba = model_1.predict_proba(test_df)
    print("   ‚úì Predictions made")

    # Read each label together with its own prediction from ONE result set.
    # test_df and the prediction frames are separate queries with no ORDER BY,
    # so fetching them independently and pairing rows by position can match
    # labels to the wrong predictions. NOTE(review): assumes predict()/
    # predict_proba() pass the input frame's columns (including the label)
    # through to their output, as snowflake-ml does by default.
    hard_pd = predictions_hard.select(['"high_quality_call"', 'PREDICTION']).to_pandas()
    proba_pd = predictions_proba.select(['"high_quality_call"', 'PREDICT_PROBA_1']).to_pandas()
    y_true_default = hard_pd['high_quality_call'].values
    y_pred_default = hard_pd['PREDICTION'].values
    y_true = proba_pd['high_quality_call'].values
    y_pred_proba = proba_pd['PREDICT_PROBA_1'].values

    print(f"   ‚úì Using {len(y_true)} samples for evaluation")

    # NOTE: the threshold is tuned on the same test rows it is then scored on,
    # so the "optimized" metrics are optimistically biased.
    optimal_threshold, optimal_f1, metric_name = find_optimal_threshold(y_true, y_pred_proba, metric='f1')
    print(f"\n   üéØ Optimal Threshold: {optimal_threshold:.3f} (maximizes {metric_name})")

    # Apply optimal threshold
    y_pred_optimized = (y_pred_proba >= optimal_threshold).astype(int)

    # Calculate metrics with default threshold (0.5)
    accuracy_def = accuracy_score(y_true_default, y_pred_default)
    precision_def = precision_score(y_true_default, y_pred_default, zero_division=0)
    recall_def = recall_score(y_true_default, y_pred_default, zero_division=0)
    f1_def = f1_score(y_true_default, y_pred_default, zero_division=0)

    # Calculate metrics with optimized threshold
    accuracy_opt = accuracy_score(y_true, y_pred_optimized)
    precision_opt = precision_score(y_true, y_pred_optimized, zero_division=0)
    recall_opt = recall_score(y_true, y_pred_optimized, zero_division=0)
    f1_opt = f1_score(y_true, y_pred_optimized, zero_division=0)

    try:
        auc = roc_auc_score(y_true, y_pred_proba)
    except Exception as e:
        print(f"   Warning: Could not calculate AUC: {e}")
        auc = 0.0

    print(f"\n   üìä Model Performance (Default Threshold = 0.5):")
    print(f"      Accuracy: {accuracy_def:.3f}")
    print(f"      Precision: {precision_def:.3f}")
    print(f"      Recall: {recall_def:.3f}")
    print(f"      F1: {f1_def:.3f}")

    print(f"\n   üìä Model Performance (Threshold = {optimal_threshold:.3f}):")
    print(f"      Accuracy: {accuracy_opt:.3f}")
    print(f"      Precision: {precision_opt:.3f}")
    print(f"      Recall: {recall_opt:.3f}")
    print(f"      F1: {f1_opt:.3f}")
    print(f"      AUC: {auc:.3f}")

    print(f"\n   üìà Improvement:")
    print(f"      Accuracy: {accuracy_opt - accuracy_def:+.3f}")
    print(f"      Precision: {precision_opt - precision_def:+.3f}")
    print(f"      Recall: {recall_opt - recall_def:+.3f}")
    print(f"      F1: {f1_opt - f1_def:+.3f}")

    exp.log_params({
        'model_type': 'xgboost_optimized',
        **xgb_params,
        'feature_source': 'all_features',
        'feature_count': len(all_features),
        'optimal_threshold': float(optimal_threshold)
    })

    # Log only optimized metrics with plain names
    exp.log_metrics({
        'accuracy': accuracy_opt,
        'precision': precision_opt,
        'recall': recall_opt,
        'f1': f1_opt,
        'auc': auc
    })
    print("   ‚úì Metrics logged")

    try:
        exp.log_model(model_1, model_name='call_quality_optimized_threshold')
        print("   ‚úì Model logged")
    except Exception as e:
        print(f"   Warning: Could not log model: {e}")

    print(f"\n   üíæ Saving optimal threshold: {optimal_threshold:.4f}")
    print(f"   üìù To use this model later, apply threshold: {optimal_threshold:.4f}")
    print(f"   üí° Example: predictions = (probabilities >= {optimal_threshold:.4f}).astype(int)")

print("‚úì Experiment 1 complete")


# ============================================================================
# EXPERIMENT 2: ACOUSTIC FEATURES ONLY
# ============================================================================

print("\n" + "="*80)
print("EXPERIMENT 2: XGBoost with Acoustic Features Only")
print("="*80)

run_name_2 = f"xgb_acoustic_only_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

with exp.start_run(run_name_2) as run:
    print(f"\n‚úì Started run: {run_name_2}")
    print("Training model...")

    acoustic_only_features = acoustic_features + interaction_features

    # Snapshot test_df into a temp table so this experiment scores a fixed set of rows.
    test_df_exp2 = test_df.cache_result()
    print(f"   üìä Test set size for this experiment: {test_df_exp2.count()} calls")

    # Single source for the hyperparameters passed to the model AND logged to the run.
    xgb_params = {
        'n_estimators': 300,
        'max_depth': 4,
        'learning_rate': 0.03,
        'min_child_weight': 1,
        'subsample': 0.7,
        'colsample_bytree': 0.7,
        'gamma': 0.1,
        'reg_alpha': 0.1,
        'reg_lambda': 1.0,
        'scale_pos_weight': scale_pos_weight * 2,
        'random_state': 42,
    }

    model_2 = XGBClassifier(
        **xgb_params,
        input_cols=acoustic_only_features,
        label_cols=['"high_quality_call"'],
        output_cols=['PREDICTION']
    )

    model_2.fit(train_df)
    print("   ‚úì Model trained")

    predictions_hard = model_2.predict(test_df_exp2)
    predictions_proba = model_2.predict_proba(test_df_exp2)
    print("   ‚úì Predictions made")

    # Read each label together with its own probability from ONE result set;
    # two separate unordered queries are not guaranteed to return rows in the
    # same order. NOTE(review): assumes predict_proba() passes the label
    # column through to its output, as snowflake-ml does by default.
    proba_pd = predictions_proba.select(['"high_quality_call"', 'PREDICT_PROBA_1']).to_pandas()
    y_true = proba_pd['high_quality_call'].values
    y_pred_proba = proba_pd['PREDICT_PROBA_1'].values

    print(f"   ‚úì Using {len(y_true)} samples for evaluation")

    # Threshold is tuned on the same test rows it is scored on (optimistic).
    optimal_threshold, _, _ = find_optimal_threshold(y_true, y_pred_proba, metric='f1')
    y_pred_optimized = (y_pred_proba >= optimal_threshold).astype(int)

    accuracy = accuracy_score(y_true, y_pred_optimized)
    precision = precision_score(y_true, y_pred_optimized, zero_division=0)
    recall = recall_score(y_true, y_pred_optimized, zero_division=0)
    f1 = f1_score(y_true, y_pred_optimized, zero_division=0)

    try:
        auc = roc_auc_score(y_true, y_pred_proba)
    except Exception as e:
        print(f"   Warning: Could not calculate AUC: {e}")
        auc = 0.0

    print(f"\n   üìä Model Performance (Threshold = {optimal_threshold:.3f}):")
    print(f"      Accuracy: {accuracy:.3f}")
    print(f"      Precision: {precision:.3f}")
    print(f"      Recall: {recall:.3f}")
    print(f"      F1: {f1:.3f}")
    print(f"      AUC: {auc:.3f}")

    exp.log_params({
        'model_type': 'xgboost_optimized',
        **xgb_params,
        'feature_source': 'acoustic_only',
        'feature_count': len(acoustic_only_features),
        'optimal_threshold': float(optimal_threshold)
    })

    exp.log_metrics({
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    })
    print("   ‚úì Metrics logged")

    try:
        exp.log_model(model_2, model_name='call_quality_acoustic_optimized')
        print("   ‚úì Model logged")
    except Exception as e:
        print(f"   Warning: Could not log model: {e}")

print("‚úì Experiment 2 complete")


# ============================================================================
# EXPERIMENT 3: NER + EMOTION FEATURES
# ============================================================================

print("\n" + "="*80)
print("EXPERIMENT 3: XGBoost with NER + Emotion Features")
print("="*80)

run_name_3 = f"xgb_ner_emotion_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

with exp.start_run(run_name_3) as run:
    print(f"\n‚úì Started run: {run_name_3}")
    print("Training model...")

    ner_emotion_features = emotion_features + sentiment_features + ner_features + interaction_features

    # Snapshot test_df into a temp table so this experiment scores a fixed set of rows.
    test_df_exp3 = test_df.cache_result()
    print(f"   üìä Test set size for this experiment: {test_df_exp3.count()} calls")

    # Single source for the hyperparameters passed to the model AND logged to the run.
    xgb_params = {
        'n_estimators': 300,
        'max_depth': 4,
        'learning_rate': 0.03,
        'min_child_weight': 1,
        'subsample': 0.7,
        'colsample_bytree': 0.7,
        'gamma': 0.1,
        'reg_alpha': 0.1,
        'reg_lambda': 1.0,
        'scale_pos_weight': scale_pos_weight * 2,
        'random_state': 42,
    }

    model_3 = XGBClassifier(
        **xgb_params,
        input_cols=ner_emotion_features,
        label_cols=['"high_quality_call"'],
        output_cols=['PREDICTION']
    )

    model_3.fit(train_df)
    print("   ‚úì Model trained")

    predictions_hard = model_3.predict(test_df_exp3)
    predictions_proba = model_3.predict_proba(test_df_exp3)
    print("   ‚úì Predictions made")

    # Read each label together with its own probability from ONE result set;
    # two separate unordered queries are not guaranteed to return rows in the
    # same order. NOTE(review): assumes predict_proba() passes the label
    # column through to its output, as snowflake-ml does by default.
    proba_pd = predictions_proba.select(['"high_quality_call"', 'PREDICT_PROBA_1']).to_pandas()
    y_true = proba_pd['high_quality_call'].values
    y_pred_proba = proba_pd['PREDICT_PROBA_1'].values

    print(f"   ‚úì Using {len(y_true)} samples for evaluation")

    # Threshold is tuned on the same test rows it is scored on (optimistic).
    optimal_threshold, _, _ = find_optimal_threshold(y_true, y_pred_proba, metric='f1')
    y_pred_optimized = (y_pred_proba >= optimal_threshold).astype(int)

    accuracy = accuracy_score(y_true, y_pred_optimized)
    precision = precision_score(y_true, y_pred_optimized, zero_division=0)
    recall = recall_score(y_true, y_pred_optimized, zero_division=0)
    f1 = f1_score(y_true, y_pred_optimized, zero_division=0)

    # AUC is undefined when the test set has a single class; report and fall back.
    try:
        auc = roc_auc_score(y_true, y_pred_proba)
    except Exception as e:
        print(f"   Warning: Could not calculate AUC: {e}")
        auc = 0.0

    print(f"\n   üìä Model Performance (Threshold = {optimal_threshold:.3f}):")
    print(f"      Accuracy: {accuracy:.3f}")
    print(f"      Precision: {precision:.3f}")
    print(f"      Recall: {recall:.3f}")
    print(f"      F1: {f1:.3f}")
    print(f"      AUC: {auc:.3f}")

    exp.log_params({
        'model_type': 'xgboost_optimized',
        **xgb_params,
        'feature_source': 'ner_emotion_sentiment',
        'feature_count': len(ner_emotion_features),
        'optimal_threshold': float(optimal_threshold)
    })

    exp.log_metrics({
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    })
    print("   ‚úì Metrics logged")

    try:
        exp.log_model(model_3, model_name='call_quality_ner_emotion_optimized')
        print("   ‚úì Model logged")
    except Exception as e:
        print(f"   Warning: Could not log model: {e}")

print("‚úì Experiment 3 complete")


# ============================================================================
# EXPERIMENT 4: SMOTE OVERSAMPLING + ALL FEATURES
# ============================================================================

print("\n" + "="*80)
print("EXPERIMENT 4: XGBoost with SMOTE Balanced Data")
print("="*80)

run_name_4 = f"xgb_smote_balanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

with exp.start_run(run_name_4) as run:
    print(f"\n‚úì Started run: {run_name_4}")
    print("Step 1: Applying SMOTE to balance training data...")

    # Convert train_df to pandas for SMOTE
    train_pd = train_df.select(all_features + ['"high_quality_call"']).to_pandas()

    # Remove quotes from column names in pandas
    train_pd.columns = [c.strip('"') for c in train_pd.columns]
    features_no_quotes = [f.strip('"') for f in all_features]

    # Separate features and labels
    X_train = train_pd[features_no_quotes].values
    y_train = train_pd['high_quality_call'].values

    print(f"   Original training data:")
    print(f"      Positive: {sum(y_train)} ({sum(y_train)/len(y_train):.2%})")
    print(f"      Negative: {len(y_train) - sum(y_train)} ({(len(y_train) - sum(y_train))/len(y_train):.2%})")

    # Apply SMOTE (best effort: any failure falls back to the original data)
    smote_applied = False
    try:
        from imblearn.over_sampling import SMOTE

        # SMOTE's default k_neighbors=5 needs at least 6 minority samples
        n_positive = sum(y_train)
        if n_positive < 6:
            print(f"\n   ‚ö†Ô∏è  WARNING: Only {n_positive} positive samples in training set")
            print(f"   SMOTE requires at least 6 samples. Using k_neighbors={max(1, n_positive - 1)}")

            if n_positive < 2:
                print(f"   ‚ö†Ô∏è  Cannot apply SMOTE with less than 2 positive samples")
                print(f"   Using original data with increased scale_pos_weight")
                train_df_balanced = train_df
                smote_applied = False
            else:
                smote = SMOTE(random_state=42, k_neighbors=min(n_positive - 1, 5))
                X_train_balanced, y_train_balanced = smote.fit_resample(X_train, y_train)
                smote_applied = True
        else:
            smote = SMOTE(random_state=42)
            X_train_balanced, y_train_balanced = smote.fit_resample(X_train, y_train)
            smote_applied = True

        if smote_applied:
            print(f"\n   ‚úì SMOTE applied successfully")
            print(f"   Balanced training data:")
            print(f"      Positive: {sum(y_train_balanced)} ({sum(y_train_balanced)/len(y_train_balanced):.2%})")
            print(f"      Negative: {len(y_train_balanced) - sum(y_train_balanced)} ({(len(y_train_balanced) - sum(y_train_balanced))/len(y_train_balanced):.2%})")
            print(f"      Total samples: {len(y_train_balanced)} (was {len(y_train)})")

            # Create balanced dataframe with proper column names for Snowpark
            balanced_data = {}
            for i, feature in enumerate(all_features):
                # Strip quotes and use clean names for pandas
                clean_name = feature.strip('"')
                balanced_data[clean_name] = X_train_balanced[:, i]
            balanced_data['high_quality_call'] = y_train_balanced

            # Create pandas DataFrame
            balanced_df = pd.DataFrame(balanced_data)

            # Convert back to Snowpark - Snowpark will add proper quoting
            train_df_balanced = session.create_dataframe(balanced_df)

            # Rename columns to match original format with quotes.
            # The loop variable must NOT be called `col`: this loop runs at
            # module level, so it would rebind the imported Snowpark `col`
            # function to a string and break every later `col(...)` call.
            for column_name in balanced_df.columns:
                train_df_balanced = train_df_balanced.with_column_renamed(column_name, f'"{column_name}"')

    except ImportError:
        print("   ‚ö†Ô∏è  WARNING: imbalanced-learn not available, using original data")
        print("   Install with: pip install imbalanced-learn")
        train_df_balanced = train_df
        smote_applied = False
    except Exception as e:
        print(f"   ‚ö†Ô∏è  WARNING: SMOTE failed with error: {e}")
        print("   Using original data with increased scale_pos_weight")
        train_df_balanced = train_df
        smote_applied = False

    print("\nStep 2: Training model on balanced data...")

    # Snapshot test_df into a temp table so this experiment scores a fixed set of rows.
    test_df_exp4 = test_df.cache_result()
    print(f"   üìä Test set size for this experiment: {test_df_exp4.count()} calls")

    # Adjust scale_pos_weight based on whether SMOTE was applied
    if smote_applied:
        scale_weight = 1.0  # Data is balanced
    else:
        scale_weight = scale_pos_weight * 3  # Increase weight if SMOTE failed

    # Single source for the hyperparameters passed to the model AND logged to the run.
    xgb_params = {
        'n_estimators': 300,
        'max_depth': 4,
        'learning_rate': 0.03,
        'min_child_weight': 1,
        'subsample': 0.7,
        'colsample_bytree': 0.7,
        'gamma': 0.1,
        'reg_alpha': 0.1,
        'reg_lambda': 1.0,
        'scale_pos_weight': scale_weight,
        'random_state': 42,
    }

    model_4 = XGBClassifier(
        **xgb_params,
        input_cols=all_features,
        label_cols=['"high_quality_call"'],
        output_cols=['PREDICTION']
    )

    model_4.fit(train_df_balanced)
    print("   ‚úì Model trained")

    predictions_hard = model_4.predict(test_df_exp4)
    predictions_proba = model_4.predict_proba(test_df_exp4)
    print("   ‚úì Predictions made")

    # Read each label together with its own probability from ONE result set;
    # two separate unordered queries are not guaranteed to return rows in the
    # same order. NOTE(review): assumes predict_proba() passes the label
    # column through to its output, as snowflake-ml does by default.
    proba_pd = predictions_proba.select(['"high_quality_call"', 'PREDICT_PROBA_1']).to_pandas()
    y_true = proba_pd['high_quality_call'].values
    y_pred_proba = proba_pd['PREDICT_PROBA_1'].values

    print(f"   ‚úì Using {len(y_true)} samples for evaluation")

    # Threshold is tuned on the same test rows it is scored on (optimistic).
    optimal_threshold, _, _ = find_optimal_threshold(y_true, y_pred_proba, metric='f1')
    y_pred_optimized = (y_pred_proba >= optimal_threshold).astype(int)

    accuracy = accuracy_score(y_true, y_pred_optimized)
    precision = precision_score(y_true, y_pred_optimized, zero_division=0)
    recall = recall_score(y_true, y_pred_optimized, zero_division=0)
    f1 = f1_score(y_true, y_pred_optimized, zero_division=0)

    # AUC is undefined when the test set has a single class; report and fall back.
    try:
        auc = roc_auc_score(y_true, y_pred_proba)
    except Exception as e:
        print(f"   Warning: Could not calculate AUC: {e}")
        auc = 0.0

    print(f"\n   üìä Model Performance (Threshold = {optimal_threshold:.3f}):")
    print(f"      Accuracy: {accuracy:.3f}")
    print(f"      Precision: {precision:.3f}")
    print(f"      Recall: {recall:.3f}")
    print(f"      F1: {f1:.3f}")
    print(f"      AUC: {auc:.3f}")

    exp.log_params({
        'model_type': 'xgboost_optimized_smote',
        **xgb_params,
        'feature_source': 'all_features_smote_balanced',
        'feature_count': len(all_features),
        'smote_applied': smote_applied,
        'optimal_threshold': float(optimal_threshold)
    })

    exp.log_metrics({
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    })
    print("   ‚úì Metrics logged")

    try:
        exp.log_model(model_4, model_name='call_quality_smote_balanced')
        print("   ‚úì Model logged")
    except Exception as e:
        print(f"   Warning: Could not log model: {e}")

print("‚úì Experiment 4 complete")

# Final report: the models produced by the four runs above and the holdout
# calls that were kept out of both training and testing.
rule = "=" * 80
print("\n" + rule)
print("ALL EXPERIMENTS COMPLETE!")
print(rule)
print("\n‚úì Trained 4 models with threshold optimization")
print("‚úì Experiment: call_quality_prediction_with_ner")
print("\nModels saved:")

saved_models = [
    ("call_quality_optimized_threshold", f"{len(all_features)} features"),
    ("call_quality_acoustic_optimized", f"{len(acoustic_features) + len(interaction_features)} features"),
    ("call_quality_ner_emotion_optimized", f"{len(ner_emotion_features)} features"),
    ("call_quality_smote_balanced", f"{len(all_features)} features, SMOTE balanced"),
]
for rank, (saved_name, detail) in enumerate(saved_models, start=1):
    print(f"   {rank}. {saved_name} ({detail})")

print("\nüìä KEY OPTIMIZATIONS:")
for bullet in (
    "Optimized hyperparameters for better generalization",
    "Doubled scale_pos_weight for class imbalance",
    "Automatic threshold tuning to maximize F1 score",
    "Comparison of default vs optimized threshold performance",
    "SMOTE oversampling for balanced training data",
):
    print(f"   ‚úÖ {bullet}")

print("\nüíæ SAVED FOR PRODUCTION:")
for bullet in (
    "Optimal thresholds logged as parameters in each experiment",
    f"{len(holdout_ids)} holdout call IDs reserved for final validation",
    "Retrieve threshold from experiment params when deploying",
):
    print(f"   ‚úÖ {bullet}")

print("\nüìã Holdout Call IDs (for final validation):")
for position, holdout_call_id in enumerate(holdout_ids, start=1):
    print(f"   {position}. {holdout_call_id}")
print(rule)

In [None]:
# ============================================================================
# FIND AND ACCESS THE REGISTERED MODEL
# ============================================================================

from snowflake.ml.registry import Registry

print("="*80)
print("FINDING REGISTERED MODEL")
print("="*80)

# ============================================================================
# STEP 1: Check current context
# ============================================================================

print("\n" + "="*80)
print("STEP 1: Current Context")
print("="*80)

current_db = session.sql("SELECT CURRENT_DATABASE()").collect()[0][0]
current_schema = session.sql("SELECT CURRENT_SCHEMA()").collect()[0][0]

print(f"   Current Database: {current_db}")
print(f"   Current Schema: {current_schema}")

# ============================================================================
# STEP 2: Find where models are registered
# ============================================================================

print("\n" + "="*80)
print("STEP 2: Finding Registered Models")
print("="*80)

target_model = 'call_quality_smote_balanced'
# Schema in which target_model was first found (NOTEBOOK is checked first,
# since that is where training ran); stays None if found in neither.
model_schema = None

for candidate_schema in ('NOTEBOOK', 'DATA'):
    print(f"\nüì¶ Checking {candidate_schema} schema...")
    try:
        # A Registry built without database/schema arguments manages the
        # session's current schema, so switch context first.
        session.sql(f"USE SCHEMA BUILD25_DEV_TO_PRODUCTION.{candidate_schema}").collect()
        models_df = Registry(session=session).show_models()
        # show_models() returns a pandas DataFrame (which has no .show()/.collect());
        # convert defensively in case a Snowpark DataFrame comes back instead.
        if hasattr(models_df, "to_pandas"):
            models_df = models_df.to_pandas()

        if len(models_df) == 0:
            print(f"   No models in {candidate_schema} schema")
            continue

        print(f"   ‚úì Found {len(models_df)} model(s) in {candidate_schema} schema:")
        print(models_df.to_string(index=False))

        # Unquoted identifiers are stored upper-case, so compare case-insensitively.
        model_names = [str(name) for name in models_df['name']]
        if target_model.upper() in {name.upper() for name in model_names}:
            print(f"   ‚úÖ Found '{target_model}' in {candidate_schema} schema!")
            if model_schema is None:
                model_schema = candidate_schema
        else:
            print(f"   Available models: {model_names}")
    except Exception as e:
        print(f"   Error: {e}")

# ============================================================================
# STEP 3: Use the correct schema and load model
# ============================================================================

print("\n" + "="*80)
print("STEP 3: Loading Model from Correct Schema")
print("="*80)

# Load from the schema where the model was found; default to NOTEBOOK
# (where training happened) when the search above did not locate it.
load_schema = model_schema or "NOTEBOOK"
try:
    session.sql(f"USE SCHEMA BUILD25_DEV_TO_PRODUCTION.{load_schema}").collect()
    print(f"‚úì Switched to {load_schema} schema")

    registry = Registry(session=session)
    model_ref = registry.get_model(target_model)
    mv = model_ref.default

    print(f"‚úÖ Model loaded successfully from {load_schema} schema!")
    print(f"   Model: {target_model}")

except Exception as e:
    print(f"‚ùå Error loading model: {e}")

    # Show all available models in the current schema to help diagnose.
    print("\nüìã All available models:")
    try:
        available = Registry(session=session).show_models()
        if hasattr(available, "to_pandas"):
            available = available.to_pandas()
        print(available.to_string(index=False))
    except Exception:
        print("   Could not list models")

    raise

# ============================================================================
# STEP 4: Deploy service to DATA schema
# ============================================================================

print("\n" + "="*80)
print("STEP 4: Deploying Service to DATA Schema")
print("="*80)

service_name = "call_quality_prediction_service"
compute_pool_name = "SYSTEM_COMPUTE_POOL_GPU"

# Even though model is in NOTEBOOK, we can deploy service to DATA schema
# by switching context before deployment

# First, drop any existing services
print("\n   Cleaning up existing services...")

for schema in ['NOTEBOOK', 'DATA']:
    try:
        session.sql(f"""
            DROP SERVICE IF EXISTS BUILD25_DEV_TO_PRODUCTION.{schema}.{service_name}
        """).collect()
        print(f"   ‚úì Dropped from {schema}")
    except Exception as e:
        # Best-effort cleanup: keep going, but surface why the drop failed
        # (e.g. missing privileges) instead of swallowing it silently.
        print(f"   Warning: could not drop service from {schema}: {e}")

import time
# Short pause after the drops before redeploying.
time.sleep(3)

# Switch to DATA schema for deployment; the unqualified service_name is
# created in the session's current schema.
session.sql("USE SCHEMA BUILD25_DEV_TO_PRODUCTION.DATA").collect()
print(f"\n‚úì Switched to DATA schema for deployment")

# Deploy the service
print(f"\n‚è≥ Deploying service...")

try:
    deployment = mv.create_service(
        service_name=service_name,
        service_compute_pool=compute_pool_name,
        max_instances=1,
        force_rebuild=True
    )

    print(f"\n‚úÖ Service deployed successfully!")
    print(f"   Service: BUILD25_DEV_TO_PRODUCTION.DATA.{service_name}")
    print(f"   Model source: BUILD25_DEV_TO_PRODUCTION.NOTEBOOK.call_quality_smote_balanced")

except Exception as e:
    print(f"\n‚ùå Deployment failed: {e}")

    # Try deploying to NOTEBOOK schema instead
    print(f"\n   Trying deployment to NOTEBOOK schema...")

    try:
        session.sql("USE SCHEMA BUILD25_DEV_TO_PRODUCTION.NOTEBOOK").collect()

        deployment = mv.create_service(
            service_name=service_name,
            service_compute_pool=compute_pool_name,
            max_instances=1,
            force_rebuild=True
        )

        print(f"\n‚úÖ Service deployed to NOTEBOOK schema!")
        print(f"   Service: BUILD25_DEV_TO_PRODUCTION.NOTEBOOK.{service_name}")

    except Exception as e2:
        print(f"\n‚ùå Alternative also failed: {e2}")
        raise

# ============================================================================
# STEP 5: Verify deployment
# ============================================================================
# Lists every service in the database, then (STEP 6) polls the DATA and
# NOTEBOOK schemas until the service is ready, has failed, or `max_checks`
# polls have elapsed. Binds `service_ready` and `service_schema` used by the
# SUMMARY.

print("\n" + "="*80)
print("STEP 5: Verifying Deployment")
print("="*80)

time.sleep(3)

# NOTE(review): Snowflake's SHOW SERVICES docs describe a `status` column
# (PENDING / RUNNING / FAILED / ...). `state` is kept as a fallback for
# output that uses that name - confirm against your account.
_SERVICE_READY_STATES = {'READY', 'RUNNING'}
_SERVICE_PENDING_STATES = {'PENDING', 'STARTING'}
_SERVICE_FAILED_STATES = {'FAILED', 'ERROR', 'INTERNAL_ERROR'}


def _service_state(service_row):
    """Return the upper-cased lifecycle state of one SHOW SERVICES row."""
    values = service_row.as_dict()
    return str(values.get('status') or values.get('state') or 'UNKNOWN').upper()


print("\nüìä All services in database:")
try:
    all_services = session.sql("""
        SHOW SERVICES IN DATABASE BUILD25_DEV_TO_PRODUCTION
    """).collect()

    for svc in all_services:
        schema = svc['schema_name']
        state = _service_state(svc)
        if state in _SERVICE_READY_STATES:
            status_icon = "‚úÖ"
        elif state in _SERVICE_PENDING_STATES:
            status_icon = "‚è≥"
        else:
            status_icon = "‚ùå"
        print(f"   {status_icon} {schema}.{svc['name']} - State: {state}")

except Exception as e:
    print(f"   Note: {e}")

# ============================================================================
# STEP 6: Wait for service
# ============================================================================

print("\n" + "="*80)
print("STEP 6: Waiting for Service to be Ready")
print("="*80)

max_checks = 12
poll_seconds = 15
service_ready = False
service_failed = False
service_schema = None

for i in range(max_checks):
    elapsed = i * poll_seconds

    # Check in both schemas: STEP 4 may have fallen back to NOTEBOOK.
    for schema in ['DATA', 'NOTEBOOK']:
        try:
            status = session.sql(f"""
                SHOW SERVICES LIKE '{service_name}' IN SCHEMA BUILD25_DEV_TO_PRODUCTION.{schema}
            """).collect()
        except Exception as e:
            print(f"   [{elapsed}s] {schema}: status check failed ({e})")
            continue

        if not status:
            continue

        state = _service_state(status[0])
        service_schema = schema
        print(f"   [{elapsed}s] {schema}: {state}")

        if state in _SERVICE_READY_STATES:
            print(f"\n‚úÖ Service is READY in {schema} schema!")
            service_ready = True
            break
        if state in _SERVICE_FAILED_STATES:
            print(f"\n‚ùå Service failed in {schema}: {state}")
            service_failed = True
            break

    # Stop polling on either terminal outcome instead of waiting out the
    # remaining checks.
    if service_ready or service_failed:
        break

    if i < max_checks - 1:
        time.sleep(poll_seconds)

# ============================================================================
# SUMMARY
# ============================================================================
# Prints where the model and service live, the readiness outcome computed in
# STEP 6 (`service_ready` / `service_schema`), and an example SQL call.
# NOTE(review): the "SETUP COMPLETE" banner prints even when the service never
# became ready; the "Service Status" line is the actual outcome.

print("\n" + "="*80)
print("‚úÖ SETUP COMPLETE")
print("="*80)

# Falls back to 'TBD' / '[SCHEMA]' placeholders when no service was found.
print(f"""
üìä Summary:
   Model Location: BUILD25_DEV_TO_PRODUCTION.NOTEBOOK.call_quality_smote_balanced
   Service Location: BUILD25_DEV_TO_PRODUCTION.{service_schema if service_schema else 'TBD'}.{service_name}
   Service Status: {'‚úÖ READY' if service_ready else '‚è≥ Check manually'}
   
üìù To make predictions, use:
   
   SELECT 
       "call_id",
       BUILD25_DEV_TO_PRODUCTION.{service_schema if service_schema else '[SCHEMA]'}.{service_name}!PREDICT(*) as prediction
   FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
   WHERE "call_id" = 'your-call-id'
   LIMIT 5;

""")

print("="*80)

In [None]:
# ============================================================================
# MAKE PREDICTIONS USING mv.run() METHOD
# ============================================================================

import snowflake.snowpark as snowpark
from snowflake.ml.registry import Registry
from snowflake.snowpark.functions import col, when, lit, current_timestamp

print("="*80)
print("MAKE PREDICTIONS USING MODEL VERSION")
print("="*80)

# ============================================================================
# STEP 1: Load Model Version
# ============================================================================
# Loads the pinned version SERIOUS_INSECT_1 of CALL_QUALITY_SMOTE_BALANCED from
# the NOTEBOOK schema into the module-level `mv` used by the later steps.

print("\n" + "="*80)
print("STEP 1: Loading Model")
print("="*80)

# Initialize registry with correct schema
reg = Registry(
    session=session,
    database_name='BUILD25_DEV_TO_PRODUCTION',
    schema_name='NOTEBOOK',
)

print("‚úì Registry initialized")
print("   Database: BUILD25_DEV_TO_PRODUCTION")
print("   Schema: NOTEBOOK")

# Get the specific model version
try:
    mv = reg.get_model('CALL_QUALITY_SMOTE_BALANCED').version('SERIOUS_INSECT_1')
    print("\n‚úÖ Model version loaded:")
    print("   Model: CALL_QUALITY_SMOTE_BALANCED")
    print("   Version: SERIOUS_INSECT_1")
except Exception as e:
    print(f"\n‚ùå Error loading model: {e}")

    # Show available models to help diagnose a wrong name/version.
    print("\nüì¶ Available models:")
    try:
        # Registry.show_models() returns a pandas DataFrame (per the
        # snowflake-ml API docs); it has no Snowpark-style .show().
        models = reg.show_models()
        print(models)
    except Exception as list_err:
        print(f"   Could not list models: {list_err}")

    raise

# ============================================================================
# STEP 2: Load Your 5 Calls
# ============================================================================
# Fetches the acoustic features of a fixed list of demo calls, joined with
# their outcomes. The ground-truth label "actual_high_quality_call" is 1 only
# when the call was resolved AND scored >= 4 on customer satisfaction.

print("\n" + "=" * 80)
print("STEP 2: Loading Target Calls")
print("=" * 80)

call_ids = [
    '0724ec7f-da66-420a-bfdf-925247fd9041',
    '09f444c8-b316-4094-a088-aafdb8269c55',
    '196d451b-ab6d-4863-9b5e-659089ccb58c',
    '1b8cb9f5-948a-4aed-b170-8e2b64cb3931',
    '20d3da90-c70b-47f7-a57b-f95139a1fd63',
]

# Rendered between the outer quotes of the IN ('...') list below, yielding
# 'id1', 'id2', ...
call_ids_str = "', '".join(call_ids)

print(f"üìã Target: {len(call_ids)} calls\n")

input_dataframe = session.sql(
    f"""
    SELECT 
        f.*,
        o.call_resolved,
        o.customer_satisfaction_score,
        CASE 
            WHEN o.call_resolved = 1 AND o.customer_satisfaction_score >= 4
            THEN 1 
            ELSE 0 
        END as "actual_high_quality_call"
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features f
    LEFT JOIN BUILD25_DEV_TO_PRODUCTION.DATA.call_outcomes o 
        ON f."call_id" = o."call_id"
    WHERE f."call_id" IN ('{call_ids_str}')
"""
)

call_count = input_dataframe.count()
print(f"‚úì Loaded {call_count} calls")

# Abort early: nothing to score.
if not call_count:
    print("‚ùå No calls found!")
    raise Exception("No calls found with these IDs")

print("\nüìã Sample input data:")
input_dataframe.select(
    *(
        '"call_id"',
        '"speaking_rate_wpm"',
        '"avg_emotion_score"',
        '"sentiment_score"',
        '"actual_high_quality_call"',
    )
).show(3)

# ============================================================================
# STEP 3: Make Predictions Using Service
# ============================================================================
# Scores `input_dataframe` with mv.run(). Options are tried in order until one
# succeeds:
#   1. PREDICT on the deployed service;
#   2. PREDICT without a service_name (mv.run's default execution target);
#   3. PREDICT_PROBA on the service - last resort only, because it returns
#      class probabilities, while STEP 4/5 treat the output as 0/1 labels.
# Binds `service_name` and `predictions` at module level.

print("\n" + "="*80)
print("STEP 3: Making Predictions")
print("="*80)

service_name = 'CALL_QUALITY_PREDICTION_SERVICE'

print("üîÑ Running predictions...")
print(f"   Service: {service_name}")
print("   Function: PREDICT")

predictions = None
prediction_function = None
last_prediction_error = None

# (function name, service to target or None, label for log messages)
prediction_attempts = [
    ('PREDICT', service_name, 'PREDICT via service'),
    ('PREDICT', None, 'PREDICT without service_name'),
    ('PREDICT_PROBA', service_name, 'PREDICT_PROBA via service'),
]

for function_name, target_service, label in prediction_attempts:
    run_kwargs = {'function_name': function_name}
    if target_service is not None:
        run_kwargs['service_name'] = target_service
    print(f"\n   Trying {label}...")
    try:
        predictions = mv.run(input_dataframe, **run_kwargs)
    except Exception as err:
        last_prediction_error = err
        print(f"‚ùå {label} failed: {str(err)[:300]}")
        continue
    prediction_function = function_name
    print(f"‚úÖ Predictions completed ({label})!")
    break

if predictions is None:
    print("\n‚ùå Could not make predictions")
    raise Exception("Prediction failed") from last_prediction_error

if prediction_function == 'PREDICT_PROBA':
    # Downstream accuracy and monitoring compare this output to 0/1 labels.
    print("‚ö†Ô∏è  PREDICT_PROBA output holds class probabilities, not 0/1 labels; "
          "accuracy and monitoring labels below will not be meaningful.")

# ============================================================================
# STEP 4: Display Results
# ============================================================================
# Shows the raw predictions, then compares each prediction to the
# ground-truth label and prints per-call details plus overall accuracy.
# Binds `results_pd`, `correct_count` and `accuracy` at module level.

print("\n" + "="*80)
print("STEP 4: Prediction Results")
print("="*80)

print("\nüìä Raw Predictions:")
predictions.show()

# Show column names
print("\nüìã Output columns:")
for col_name in predictions.columns:
    print(f"   - {col_name}")

# Convert to pandas for detailed analysis
results_pd = predictions.to_pandas()

print(f"\n‚úì Got {len(results_pd)} predictions")


def _first_value(row, *names, default='N/A'):
    """Return row[name] for the first of `names` present in `row`, else `default`.

    Quoted SQL identifiers ("call_id") come back lower-case while unquoted
    ones (o.call_resolved) come back upper-case, so callers pass both forms.
    """
    for name in names:
        if name in row.index:
            return row[name]
    return default


def _labels_match(predicted, actual):
    """Compare two labels numerically when possible (so 1.0 matches 1)."""
    try:
        return float(predicted) == float(actual)
    except (TypeError, ValueError):
        return str(predicted) == str(actual)


# Pick the model-output column once: prefer a PREDICTION-named column,
# otherwise fall back to an OUTPUT-named one.
pred_cols = (
    [c for c in results_pd.columns if 'PREDICTION' in c.upper()]
    or [c for c in results_pd.columns if 'OUTPUT' in c.upper()]
)

# Detailed results
print("\nüìä Detailed Results:\n")

correct_count = 0

if not pred_cols:
    print("‚ö†Ô∏è  No prediction/output column found - cannot compare with labels")
else:
    pred_col = pred_cols[0]
    for position, (_, row) in enumerate(results_pd.iterrows(), start=1):
        call_id = str(_first_value(row, 'call_id', 'CALL_ID', default='Unknown'))
        call_id_short = call_id[:8] + "..." if len(call_id) > 8 else call_id

        pred_value = row[pred_col]
        actual_value = _first_value(
            row, 'actual_high_quality_call', 'ACTUAL_HIGH_QUALITY_CALL'
        )

        is_correct = _labels_match(pred_value, actual_value)
        if is_correct:
            correct_count += 1

        match_icon = "‚úÖ" if is_correct else "‚ùå"

        print(f"{match_icon} Call {position}: {call_id_short}")
        print(f"   Predicted:    {pred_value}")
        print(f"   Actual:       {actual_value}")
        print(f"   Resolved:     {_first_value(row, 'CALL_RESOLVED', 'call_resolved')}")
        print(f"   Satisfaction: {_first_value(row, 'CUSTOMER_SATISFACTION_SCORE', 'customer_satisfaction_score')}")
        print()

accuracy = correct_count / len(results_pd) if len(results_pd) > 0 else 0
print(f"üìà Accuracy: {correct_count}/{len(results_pd)} ({accuracy*100:.1f}%)")

# ============================================================================
# STEP 5: Save for Monitoring
# ============================================================================
# Writes one row per scored call (ID, TIMESTAMP, OUTPUT_ENCODED, LABEL_ENCODED
# plus model metadata) to DATA.call_quality_predictions_monitored. The table
# is overwritten on every run. Failures are reported but do not stop the cell.

print("\n" + "="*80)
print("STEP 5: Saving for Monitoring")
print("="*80)

try:
    # Recorded as metadata only; it is not applied to the predictions here.
    optimal_threshold = 0.45

    # Add metadata columns
    predictions_with_meta = (
        predictions
        .with_column('TIMESTAMP', current_timestamp())
        .with_column('model_threshold', lit(optimal_threshold))
        .with_column('model_name', lit('CALL_QUALITY_SMOTE_BALANCED'))
        .with_column('model_version', lit('SERIOUS_INSECT_1'))
        .with_column('service_name', lit(service_name))
    )

    # Find the model-output column: prefer a PREDICTION-named column,
    # otherwise accept an OUTPUT-named one (e.g. generated output features).
    meta_columns = predictions_with_meta.columns
    pred_cols = (
        [c for c in meta_columns if 'PREDICTION' in c.upper()]
        or [c for c in meta_columns if 'OUTPUT' in c.upper()]
    )
    if not pred_cols:
        raise ValueError(f"No prediction/output column in model output: {meta_columns}")
    pred_col_name = pred_cols[0]

    # Format for monitoring
    monitoring_df = predictions_with_meta.select(
        col('"call_id"').alias('ID'),
        col('TIMESTAMP'),
        col(pred_col_name).alias('OUTPUT_ENCODED'),
        col('"actual_high_quality_call"').alias('LABEL_ENCODED'),
        col('model_threshold'),
        col('model_name'),
        col('model_version'),
        col('service_name'),
    )

    # Save
    monitoring_df.write.mode('overwrite').save_as_table(
        'BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored'
    )

    print("‚úÖ Saved to: BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored")

    # Verify
    saved_count = session.sql("""
        SELECT COUNT(*) as cnt 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored
    """).collect()[0]['CNT']

    print(f"‚úì Table contains {saved_count} records")

    # Show sample
    print("\nüìã Sample from monitoring table:")
    session.sql("""
        SELECT * 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored
        LIMIT 3
    """).show()

except Exception as e:
    print(f"‚ö†Ô∏è  Error saving: {e}")
    print("   Predictions were made but couldn't save to monitoring table")

# ============================================================================
# SUMMARY
# ============================================================================
# Prints a recap of the prediction run from module-level values set above
# (`service_name`, `call_count`, `accuracy`).
# NOTE(review): `call_count` is the number of input rows loaded in STEP 2, not
# the number of predictions. "Saved to" and "5 calls" are printed
# unconditionally, even if STEP 5 failed or a different call list was used.

print("\n" + "="*80)
print("‚úÖ PREDICTIONS COMPLETE!")
print("="*80)

print(f"""
üìä Summary:
   ‚úì Model: CALL_QUALITY_SMOTE_BALANCED
   ‚úì Version: SERIOUS_INSECT_1
   ‚úì Service: {service_name}
   ‚úì Method: mv.run()
   
   ‚úì Predictions: {call_count} calls
   ‚úì Accuracy: {accuracy*100:.1f}%
   ‚úì Saved to: call_quality_predictions_monitored

üìù Next Step:
   Run the monitoring setup script to create the model monitor!

üí° What we did:
   1. Loaded model version from registry
   2. Used mv.run() with service name
   3. Made predictions on 5 calls
   4. Saved results for monitoring

""")

print("="*80)

In [None]:
-- Build a monitoring baseline from the 100 most recent rows of the NTZ
-- predictions view, then (re)create the model monitor over that view.
-- NOTE(review): call_quality_predictions_monitored_ntz (with its
-- event_time_ntz column) is not created in this cell; presumably it is a view
-- over DATA.call_quality_predictions_monitored defined elsewhere - confirm.
-- NOTE(review): the ACCOUNTADMIN role and the JAMES_XS warehouse are
-- hard-coded; consider a less privileged role and a configurable warehouse.
USE ROLE ACCOUNTADMIN;
USE DATABASE BUILD25_DEV_TO_PRODUCTION;
USE SCHEMA DATA;

-- Recreate the baseline using the NTZ view/column names
CREATE OR REPLACE TABLE call_quality_baseline AS
SELECT
  ID,                                -- keep as-is (matches source)
  event_time_ntz,                    -- NTZ timestamp (matches source)
  OUTPUT_ENCODED,                    -- keep as-is unless you see a type mismatch
  CAST(LABEL_ENCODED AS NUMBER(38,0)) AS LABEL_ENCODED  -- explicit integer label type; presumably fixes a type mismatch with the source - confirm
FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored_ntz
ORDER BY event_time_ntz DESC
LIMIT 100;

-- The monitor is created next to the model, which is registered in NOTEBOOK.
USE SCHEMA NOTEBOOK;

DROP MODEL MONITOR IF EXISTS call_quality_monitor;

-- Binds the monitor to the PREDICT function of version SERIOUS_INSECT_1.
-- The column names below must exist in both the SOURCE view and the BASELINE.
CREATE MODEL MONITOR call_quality_monitor
WITH
    MODEL = "CALL_QUALITY_SMOTE_BALANCED"
    VERSION = 'SERIOUS_INSECT_1'
    FUNCTION = 'PREDICT'
    SOURCE = BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_predictions_monitored_ntz
    BASELINE = BUILD25_DEV_TO_PRODUCTION.DATA.call_quality_baseline
    TIMESTAMP_COLUMN = event_time_ntz
    ID_COLUMNS = ('ID')
    PREDICTION_CLASS_COLUMNS = ('OUTPUT_ENCODED')
    ACTUAL_CLASS_COLUMNS = ('LABEL_ENCODED')
    WAREHOUSE = JAMES_XS
    REFRESH_INTERVAL = '1 hour'
    AGGREGATION_WINDOW = '1 day';

In [None]:
"""
Setup Online Feature Store
Creates feature views, waits for materialization, then enables online serving
"""

import snowflake.snowpark as snowpark
from snowflake.ml.feature_store import FeatureStore, Entity, OnlineConfig
from snowflake.ml.feature_store.feature_view import FeatureView, StoreType
from snowflake.snowpark.functions import col
import time

current_warehouse = session.get_current_warehouse()

print("=" * 80)
print("ONLINE FEATURE STORE SETUP - PROPER ORDER")
print("=" * 80)

# ============================================================================
# PHASE 1: Setup Offline First (Dynamic Tables)
# ============================================================================
# Connects to (creating if needed) the CALL_QUALITY_FS feature store and
# makes the `call` entity, keyed on call_id, available as `call_entity`.

print("\n" + "=" * 80)
print("PHASE 1: CREATE FEATURE VIEWS WITH ONLINE CONFIG")
print("=" * 80)

print("\n1. Creating/connecting to Feature Store...")
fs = FeatureStore(
    session=session,
    name="CALL_QUALITY_FS",
    database="BUILD25_DEV_TO_PRODUCTION",
    creation_mode="create_if_not_exist",
    default_warehouse=current_warehouse,
)
print("   ‚úì Feature Store ready")

print("\n2. Registering entity...")
try:
    # Build the entity, then replace it with the registered copy.
    call_entity = Entity(
        join_keys=["call_id"],
        name="call",
        desc="Individual customer service call",
    )
    call_entity = fs.register_entity(call_entity)
    print("   ‚úì Entity registered")
except Exception as e:
    # Any registration error is treated as "already registered".
    call_entity = fs.get_entity("call")
    print("   ‚úì Entity exists")

print("\n3. Creating feature views WITH online config from start...")

# Names of feature views whose registration failed, reported at the end.
failed_feature_views = []


def _register_online_feature_view(name, feature_df, desc):
    """Build version v1 of a feature view with online serving and register it.

    All views share the `call_entity` entity, the "timestamp_col" timestamp
    column, a 1-minute INCREMENTAL refresh and a 30-second online target lag;
    only the name, query and description differ.

    Returns:
        The registered FeatureView on success. On failure, prints the error,
        records `name` in `failed_feature_views` and returns the unregistered
        FeatureView object.
    """
    feature_view = FeatureView(
        name=name,
        entities=[call_entity],
        feature_df=feature_df,
        timestamp_col="timestamp_col",
        refresh_freq="1 minute",
        refresh_mode="INCREMENTAL",
        desc=desc,
        online_config=OnlineConfig(enable=True, target_lag="30 seconds"),
    )
    try:
        registered = fs.register_feature_view(
            feature_view=feature_view,
            version="v1",
            overwrite=True,
        )
    except Exception as err:
        print(f"      ‚úó Error: {err}")
        failed_feature_views.append(name)
        return feature_view
    print(f"      ‚úì {name} created with online serving")
    return registered


# Acoustic Features
print("   Creating acoustic_features with online serving...")
session.use_schema('CALL_QUALITY_FS')

# Unquoted aliases make the feature columns upper-case (SPEAKING_RATE_WPM...).
acoustic_df = session.sql("""
    SELECT 
        "call_id" as call_id,
        "speaking_rate_wpm" as speaking_rate_wpm,
        "speech_rate_variability" as speech_rate_variability,
        "avg_pause_duration_sec" as avg_pause_duration_sec,
        "pause_frequency_per_min" as pause_frequency_per_min,
        "avg_pitch_hz" as avg_pitch_hz,
        "pitch_variance" as pitch_variance,
        "pitch_range_hz" as pitch_range_hz,
        "energy_mean" as energy_mean,
        "energy_variance" as energy_variance,
        "dynamic_range_db" as dynamic_range_db,
        "spectral_centroid" as spectral_centroid,
        "harmonics_to_noise_ratio" as harmonics_to_noise_ratio,
        "jitter" as jitter,
        "shimmer" as shimmer,
        "zero_crossing_rate" as zero_crossing_rate,
        "silence_ratio" as silence_ratio,
        "speech_to_silence_ratio" as speech_to_silence_ratio,
        "interruption_count" as interruption_count,
        "agent_talk_ratio" as agent_talk_ratio,
        "turn_taking_rate" as turn_taking_rate,
        "avg_turn_duration_sec" as avg_turn_duration_sec,
        "processed_at" as timestamp_col
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
""")
acoustic_fv = _register_online_feature_view(
    "acoustic_features", acoustic_df, "Acoustic features for agent coaching"
)

# Emotion & Sentiment Features
print("   Creating emotion_sentiment_features with online serving...")
emotion_df = session.sql("""
    SELECT 
        "call_id" as call_id,
        "avg_emotion_score" as avg_emotion_score,
        "emotion_volatility" as emotion_volatility,
        "stress_indicators" as stress_indicators,
        "sentiment_score" as sentiment_score,
        "sentiment_label" as sentiment_label,
        "dominant_emotion" as dominant_emotion,
        "processed_at" as timestamp_col
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
""")
emotion_fv = _register_online_feature_view(
    "emotion_sentiment_features", emotion_df, "Emotion and sentiment features"
)

# NER & Content Features
print("   Creating ner_content_features with online serving...")
ner_df = session.sql("""
    SELECT 
        "call_id" as call_id,
        "word_count" as word_count,
        "entity_count" as entity_count,
        "entities_person" as entities_person,
        "entities_org" as entities_org,
        "entities_loc" as entities_loc,
        "processed_at" as timestamp_col
    FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
""")
ner_fv = _register_online_feature_view(
    "ner_content_features", ner_df, "NER and content features"
)

if failed_feature_views:
    print(f"\n   ‚ö†Ô∏è  Feature views NOT registered: {', '.join(failed_feature_views)}")
else:
    print("\n   ‚úì All feature views created WITH online config")

# ============================================================================
# PHASE 2: Wait for Online Tables to Materialize
# ============================================================================
# Polls the backing table of each feature view every `check_interval` seconds
# until all three contain rows, or `max_wait` seconds have passed. Binds
# `all_materialized` at module level.

print("\n" + "="*80)
print("PHASE 2: WAIT FOR ONLINE FEATURE TABLES TO MATERIALIZE")
print("="*80)

print("\nOnline feature tables need to populate from dynamic tables...")
print("This happens automatically in the background.")
print("Checking every 15 seconds (max 3 minutes)...")

max_wait = 180  # 3 minutes
check_interval = 15
elapsed = 0
all_materialized = False

feature_view_names = ["acoustic_features", "emotion_sentiment_features", "ner_content_features"]

# Backing-table names resolved from the feature store itself (cached per
# view) instead of being rebuilt by hand. This way the identifier, including
# the casing of the "$v1" suffix, is whatever the feature store created.
feature_view_tables = {}

while elapsed < max_wait and not all_materialized:
    time.sleep(check_interval)
    elapsed += check_interval

    print(f"\n‚è±Ô∏è  Check at {elapsed}s:")
    ready_count = 0

    for fv_name in feature_view_names:
        label = f"{fv_name.upper()}$v1"  # display only
        try:
            table_name = feature_view_tables.get(fv_name)
            if table_name is None:
                table_name = fs.get_feature_view(fv_name, "v1").fully_qualified_name()
                feature_view_tables[fv_name] = table_name
            count = session.sql(f"""
                SELECT COUNT(*) as cnt 
                FROM {table_name}
            """).collect()[0]['CNT']

            if count > 0:
                print(f"   ‚úì {label}: {count} rows - READY")
                ready_count += 1
            else:
                print(f"   ‚è≥ {label}: 0 rows - waiting...")
        except Exception as e:
            # Expected while the view/table is still being created; the
            # short reason helps spot persistent errors.
            print(f"   ‚è≥ {label}: Not materialized yet... ({str(e)[:80]})")

    if ready_count == len(feature_view_names):
        all_materialized = True
        print("\n   ‚úÖ All tables have data!")
        break

if not all_materialized:
    print("\n   ‚ö†Ô∏è  Some tables still materializing")
    print("   Online serving may still be initializing in background...")

# ============================================================================
# PHASE 3: Test Online Retrieval
# ============================================================================
# Reads three acoustic features for one call from the ONLINE store and sets
# the module-level `online_working` flag used by the summary.

print("\n" + "="*80)
print("PHASE 3: TEST ONLINE RETRIEVAL")
print("="*80)

print("\nWaiting 30 seconds for online tables to populate...")
time.sleep(30)

print("\nTesting online feature retrieval...")

try:
    # Get a sample call_id; fail with a clear reason (rather than an
    # IndexError) when the source table is empty.
    sample_rows = session.sql("""
        SELECT "call_id" 
        FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features 
        LIMIT 1
    """).collect()
    if not sample_rows:
        raise ValueError("DATA.call_acoustic_features is empty - no call_id to test with")
    sample_call = sample_rows[0]['call_id']

    print(f"Testing with call_id: {sample_call}")

    # Get feature view
    acoustic_fv = fs.get_feature_view("acoustic_features", "v1")

    # The FeatureView object already carries its version, so none is passed.
    # Feature names are upper-case because the feature_df aliases are unquoted.
    features = fs.read_feature_view(
        feature_view=acoustic_fv,
        keys=[[sample_call]],
        feature_names=["SPEAKING_RATE_WPM", "AVG_PITCH_HZ", "ENERGY_MEAN"],
        store_type=StoreType.ONLINE,
    )

    print("\n‚úÖ SUCCESS! Online retrieval working!")
    print("\nRetrieved features:")
    features.show()

    online_working = True

except Exception as e:
    print(f"\n‚úó Online retrieval failed: {e}")
    print("\nThis could mean:")
    print("   - Online tables still populating (wait another 1-2 minutes)")
    print("   - Online serving not available in your environment")
    print("   - Need to grant additional permissions")
    online_working = False

# ============================================================================
# Summary
# ============================================================================
# Prints a status report whose content depends only on `online_working`
# (set by PHASE 3). Each tuple entry is printed on its own line.

print("\n" + "=" * 80)
print("SETUP COMPLETE!")
print("=" * 80)

if not online_working:
    print("\n".join((
        "\n‚ö†Ô∏è  ONLINE MODE: NOT WORKING",
        "\nüìä Current Status:",
        "   ‚Ä¢ Feature Store: CALL_QUALITY_FS - ‚úì",
        "   ‚Ä¢ Feature Views: 3 (offline mode) - ‚úì",
        "   ‚Ä¢ Dynamic Tables: Materialized - ‚úì",
        "   ‚Ä¢ Online Serving: Failed to enable - ‚úó",
        "\nüîÑ Options:",
        "\n   OPTION 1: Wait and Retry",
        "   - Wait another 2-3 minutes",
        "   - Run just the Phase 3 section again",
        "   - Online tables may still be initializing",
        "\n   OPTION 2: Use Offline Mode (RECOMMENDED)",
        "   - Offline mode works perfectly right now",
        "   - Query dynamic tables directly (fast)",
        "   - Production-ready, no preview features",
        "   - Use the OFFLINE Streamlit dashboard",
        "\n   OPTION 3: Check Permissions",
        "   - You may need ACCOUNTADMIN to grant:",
        "   - CREATE ONLINE FEATURE TABLE privilege",
        "   - Contact your Snowflake admin",
    )))
else:
    print("\n".join((
        "\n‚úÖ ONLINE MODE: ENABLED AND WORKING!",
        "\nüìä Configuration:",
        "   ‚Ä¢ Feature Store: CALL_QUALITY_FS",
        "   ‚Ä¢ Feature Views: 3 (all with online serving)",
        "   ‚Ä¢ Target Lag: 30 seconds",
        "   ‚Ä¢ Mode: ONLINE (low-latency point lookups)",
        "\nüéØ Next Steps:",
        "   1. Use the ONLINE version of Streamlit dashboard",
        "   2. Features available with <30 second latency",
        "   3. Monitor with: fs.get_refresh_history(fv, StoreType.ONLINE)",
    )))

print("\n" + "=" * 80)

In [None]:
"""
Real-Time Agent Coaching Dashboard
Streamlit app for live call quality monitoring and coaching
"""

import streamlit as st

# Now import everything else
import snowflake.snowpark as snowpark
from snowflake.ml.feature_store import FeatureStore
from snowflake.ml.feature_store.feature_view import StoreType
from snowflake.ml.registry import Registry
import pandas as pd
from datetime import datetime, timedelta
import time

# Custom CSS
# Injects page-wide CSS classes for large metrics and colour-coded coaching
# alerts (high = red, medium = amber, good = green). unsafe_allow_html=True
# lets Streamlit render the raw <style> tag instead of escaping it.
# NOTE(review): none of these classes are referenced in the visible part of
# this app - confirm they are used further down.
st.markdown("""
<style>
    .big-metric {
        font-size: 48px;
        font-weight: bold;
        text-align: center;
    }
    .coaching-alert {
        padding: 15px;
        border-radius: 5px;
        margin: 10px 0;
        font-weight: bold;
    }
    .alert-high {
        background-color: #ff4444;
        color: white;
    }
    .alert-medium {
        background-color: #ffaa00;
        color: white;
    }
    .alert-good {
        background-color: #00cc66;
        color: white;
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state
# The auto-refresh toggle lives in st.session_state so it survives reruns;
# it defaults to off on the first run.
if 'auto_refresh' not in st.session_state:
    st.session_state.auto_refresh = False

# ============================================================================
# Initialize Connections
# ============================================================================

@st.cache_resource
def get_snowflake_connection():
    """Return the Snowpark session, or None if one cannot be obtained.

    Cached with st.cache_resource, so reruns reuse the same session. On
    failure the error is shown in the app instead of being raised.
    """
    try:
        return snowpark.Session.builder.getOrCreate()
    except Exception as e:
        st.error(f"Failed to connect to Snowflake: {e}")
    return None

@st.cache_resource
def get_feature_store(_session):
    """Open the existing CALL_QUALITY_FS feature store, or return None.

    The leading underscore on `_session` keeps Streamlit from trying to hash
    the session when caching. The store's default warehouse is the session's
    current one. creation_mode "fail_if_not_exist" means this never creates
    the store; on any error two messages are shown in the app.
    """
    try:
        return FeatureStore(
            session=_session,
            database="BUILD25_DEV_TO_PRODUCTION",
            name="CALL_QUALITY_FS",
            default_warehouse=_session.get_current_warehouse(),
            creation_mode="fail_if_not_exist",
        )
    except Exception as e:
        st.error(f"Failed to load Feature Store: {e}")
        st.error("Make sure you've run the Feature Store setup script first!")
        return None

# ============================================================================
# Helper Functions
# ============================================================================

def get_active_calls(session):
    """Return up to 20 distinct (call_id, agent_id) pairs as a pandas DataFrame.

    No ORDER BY is applied, so which 20 calls come back is not defined (they
    are not necessarily the most recent). On failure the error is shown in
    the app and an empty frame with the same two columns is returned.
    """
    query = """
            SELECT DISTINCT "call_id", "agent_id"
            FROM BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features
            LIMIT 20
        """
    try:
        return session.sql(query).to_pandas()
    except Exception as e:
        st.error(f"Error loading calls: {e}")
        return pd.DataFrame(columns=['call_id', 'agent_id'])

def get_call_features(fs, call_id):
    """Fetch one call's features from the ONLINE store of three feature views.

    Looks up version v1 of the acoustic, emotion/sentiment and NER views, reads
    the listed upper-case features for `call_id` from each, and concatenates
    the three pandas frames side by side (axis=1). On any error, the message
    is shown in the app and None is returned.
    """
    # (feature view name, features to read) - fetched in this order.
    feature_requests = (
        ("acoustic_features", [
            "SPEAKING_RATE_WPM", "SPEECH_RATE_VARIABILITY",
            "INTERRUPTION_COUNT", "AGENT_TALK_RATIO",
            "ENERGY_MEAN", "AVG_PAUSE_DURATION_SEC",
        ]),
        ("emotion_sentiment_features", [
            "AVG_EMOTION_SCORE", "STRESS_INDICATORS",
            "SENTIMENT_SCORE", "DOMINANT_EMOTION",
        ]),
        ("ner_content_features", ["WORD_COUNT", "ENTITY_COUNT"]),
    )
    try:
        views = [fs.get_feature_view(view_name, "v1") for view_name, _ in feature_requests]
        frames = [
            fs.read_feature_view(
                feature_view=view,
                keys=[[call_id]],
                feature_names=wanted,
                store_type=StoreType.ONLINE,
            ).to_pandas()
            for view, (_, wanted) in zip(views, feature_requests)
        ]
        return pd.concat(frames, axis=1)
    except Exception as e:
        st.error(f"Error retrieving features: {e}")
        return None

def generate_coaching_suggestions(features):
    """Generate coaching suggestions from a call's features.

    Args:
        features: pandas DataFrame whose first row holds the call's features
            (upper-case column names such as STRESS_INDICATORS), or None.

    Returns:
        (suggestions, priority_scores): parallel lists of message strings and
        priorities (3 = high/red, 2 = medium/yellow, 1 = low/green). Both are
        empty when `features` is None or has no rows. When no rule fires, a
        single "progressing well" message with priority 1 is returned.

    A feature that is absent, None or NaN uses the neutral default shown in
    each rule, so an unknown value never triggers an alert or raises.
    """
    if features is None or len(features) == 0:
        return [], []

    row = features.iloc[0]

    def feature(name, default):
        """Return row[name], or `default` when absent, None or NaN."""
        value = row.get(name, default)
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return default
        return value

    suggestions = []
    priority_scores = []

    def add(message, priority):
        suggestions.append(message)
        priority_scores.append(priority)

    sentiment = feature('SENTIMENT_SCORE', 0.5)

    # High priority alerts (RED)
    if feature('STRESS_INDICATORS', 0) > 0.6:
        add("üî¥ HIGH STRESS DETECTED - Suggest calming breath or brief pause", 3)

    if sentiment < 0.3:
        add("üî¥ VERY NEGATIVE SENTIMENT - Consider immediate escalation", 3)

    # Medium priority alerts (YELLOW)
    if feature('SPEAKING_RATE_WPM', 150) > 180:
        add("‚ö†Ô∏è SPEAKING TOO FAST - Coach to slow down and enunciate", 2)

    if feature('INTERRUPTION_COUNT', 0) > 5:
        add("‚ö†Ô∏è EXCESSIVE INTERRUPTIONS - Practice active listening", 2)

    agent_talk = feature('AGENT_TALK_RATIO', 0.5)
    if agent_talk > 0.7:
        add("‚ö†Ô∏è AGENT DOMINATING CONVERSATION - Ask open-ended questions", 2)
    elif agent_talk < 0.3:
        add("‚ö†Ô∏è CUSTOMER DOMINATING - Redirect conversation politely", 2)

    if feature('AVG_PAUSE_DURATION_SEC', 0) > 3:
        add("‚ö†Ô∏è LONG PAUSES - Check system response or provide updates", 2)

    # Low priority (GREEN)
    if sentiment > 0.7:
        add("‚úÖ POSITIVE SENTIMENT - Great job, keep it up!", 1)

    if not suggestions:
        add("‚úÖ CALL PROGRESSING WELL - No immediate action needed", 1)

    return suggestions, priority_scores

def get_quality_prediction(features):
    """Return a simplified, rule-based ``(quality, confidence)`` estimate.

    Four checks are run against ``features.iloc[0]`` (UPPERCASE column
    names; an absent column uses the default shown):

    * ``STRESS_INDICATORS < 0.5``          (default 0)
    * ``SENTIMENT_SCORE > 0.5``            (default 0.5)
    * ``INTERRUPTION_COUNT < 3``           (default 0)
    * ``0.4 < AGENT_TALK_RATIO < 0.6``     (default 0.5)

    Args:
        features: DataFrame of call features, or None.

    Returns:
        ``(quality, confidence)``. *quality* is 1 when at least three checks
        pass, else 0. *confidence* is the fraction of checks that passed
        (0.0-1.0), not a calibrated probability. ``(0, 0.5)`` is returned
        when *features* is None or has no rows.
    """
    if features is None or len(features) == 0:
        return 0, 0.5

    row = features.iloc[0]

    def feature(key, default):
        # None / pd.NA / non-numeric cells would make the comparisons raise
        # TypeError; map them to NaN so the check fails, as an existing NaN
        # cell already does.
        try:
            return float(row.get(key, default))
        except (TypeError, ValueError):
            return float('nan')

    checks = (
        feature('STRESS_INDICATORS', 0) < 0.5,
        feature('SENTIMENT_SCORE', 0.5) > 0.5,
        feature('INTERRUPTION_COUNT', 0) < 3,
        0.4 < feature('AGENT_TALK_RATIO', 0.5) < 0.6,
    )
    score = sum(checks)

    quality = 1 if score >= 3 else 0
    confidence = score / len(checks)

    return quality, confidence

# ============================================================================
# Main Dashboard
# ============================================================================

# --- Page header -------------------------------------------------------------
st.title("üéØ Real-Time Agent Coaching Dashboard")
st.markdown("Live call quality monitoring and instant coaching suggestions")

# --- Sidebar: refresh controls, filters and a usage tip ----------------------
with st.sidebar:
    st.header("‚öôÔ∏è Controls")

    # Mirror the checkbox into session state so the auto-refresh check at the
    # end of the script sees the latest choice on every rerun.
    auto_refresh = st.checkbox(
        "Auto-refresh (5s)",
        value=st.session_state.auto_refresh,
    )
    st.session_state.auto_refresh = auto_refresh

    refresh_clicked = st.button("üîÑ Refresh Now")
    if refresh_clicked:
        st.rerun()

    st.markdown("---")
    st.header("üìä Filters")
    # NOTE(review): show_all_calls is not read anywhere in the visible code,
    # so this filter currently has no effect on the call list.
    show_all_calls = st.checkbox("Show all calls", value=True)

    st.markdown("---")
    st.info("üí° **Tip:** Select a call to see detailed coaching suggestions")

# Initialize connections -- every guard below ends the script run with
# st.stop(), so nothing after a failed prerequisite is rendered.
def _halt(message):
    """Show *message* as an error banner and stop the current script run."""
    st.error(message)
    st.stop()


session = get_snowflake_connection()
if session is None:
    _halt("Cannot connect to Snowflake")

fs = get_feature_store(session)
if fs is None:
    _halt("Cannot load Feature Store")

# Get active calls; an empty result is reported as a warning, not an error.
calls_df = get_active_calls(session)
if calls_df.empty:
    st.warning("No active calls found. Please check if data exists in BUILD25_DEV_TO_PRODUCTION.DATA.call_acoustic_features")
    st.stop()

# The call selector below reads both of these columns.
if not all(name in calls_df.columns for name in ('call_id', 'agent_id')):
    _halt(f"Expected columns not found. Available columns: {calls_df.columns.tolist()}")

# ============================================================================
# Call Selection
# ============================================================================

st.header("üìû Active Calls")

# Map each call_id to its agent_id once, keeping the first row per call_id so
# duplicate ids resolve to their first agent. A dict lookup avoids re-scanning
# calls_df for every option. An id that cannot be matched back (e.g. NaN, which
# never compares equal) shows 'unknown' instead of raising IndexError.
_first_row_per_call = calls_df.drop_duplicates(subset='call_id', keep='first')
agent_by_call = dict(
    zip(
        _first_row_per_call['call_id'].tolist(),
        _first_row_per_call['agent_id'].tolist(),
    )
)

# Create columns for call selection
col1, col2 = st.columns([3, 1])

with col1:
    selected_call = st.selectbox(
        "Select a call to monitor:",
        options=calls_df['call_id'].tolist(),
        format_func=lambda x: f"Call: {x} | Agent: {agent_by_call.get(x, 'unknown')}",
    )

with col2:
    st.metric("Total Active", len(calls_df))

# ============================================================================
# Real-Time Monitoring
# ============================================================================


def _as_float(value):
    """Return *value* as a float; None, pd.NA and non-numeric values become NaN.

    Formatting None with ``:.1%`` / ``:.0f`` raises TypeError, so missing
    cells are normalized to NaN (rendered as "nan") rather than crashing.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return float('nan')


def _as_count(value):
    """Return *value* truncated to int; None, NaN, inf and non-numeric become 0.

    ``int()`` raises on those inputs; 0 matches the default used when the
    column is absent.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return 0


def _progress_fraction(value):
    """Return *value* as a float clamped to [0.0, 1.0] for ``st.progress``.

    ``st.progress`` raises for floats outside [0.0, 1.0] (including NaN) and
    for non-``float`` scalars, so the value is coerced, NaN maps to 0.0 and
    out-of-range values are clipped to the bounds.
    """
    import math

    fraction = _as_float(value)
    if math.isnan(fraction):
        return 0.0
    return min(max(fraction, 0.0), 1.0)


# `is not None` rather than truthiness so a falsy but valid id (e.g. 0) is
# still monitored.
if selected_call is not None:
    st.markdown("---")
    st.header(f"üé§ Monitoring Call: {selected_call}")

    # Get features
    with st.spinner("Loading call data..."):
        features = get_call_features(fs, selected_call)

    if features is not None and len(features) > 0:
        row = features.iloc[0]

        # Get prediction
        quality, confidence = get_quality_prediction(features)

        # Headline values, coerced once so formatting and progress bars cannot
        # raise on None/NaN cells.
        stress = _as_float(row.get('STRESS_INDICATORS', 0))
        sentiment = _as_float(row.get('SENTIMENT_SCORE', 0.5))
        interruptions = _as_count(row.get('INTERRUPTION_COUNT', 0))

        # Top metrics row
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            quality_color = "üü¢" if quality == 1 else "üî¥"
            st.markdown(f"### {quality_color} Call Quality")
            st.markdown(f"<div class='big-metric'>{'HIGH' if quality == 1 else 'LOW'}</div>",
                        unsafe_allow_html=True)
            st.progress(confidence)

        with col2:
            stress_color = "üî¥" if stress > 0.6 else "üü°" if stress > 0.4 else "üü¢"
            st.markdown(f"### {stress_color} Stress Level")
            st.markdown(f"<div class='big-metric'>{stress:.1%}</div>", unsafe_allow_html=True)
            st.progress(_progress_fraction(stress))

        with col3:
            sent_color = "üü¢" if sentiment > 0.6 else "üü°" if sentiment > 0.4 else "üî¥"
            st.markdown(f"### {sent_color} Sentiment")
            st.markdown(f"<div class='big-metric'>{sentiment:.1%}</div>", unsafe_allow_html=True)
            st.progress(_progress_fraction(sentiment))

        with col4:
            int_color = "üî¥" if interruptions > 5 else "üü°" if interruptions > 3 else "üü¢"
            st.markdown(f"### {int_color} Interruptions")
            st.markdown(f"<div class='big-metric'>{interruptions}</div>", unsafe_allow_html=True)

        st.markdown("---")

        # Coaching Suggestions
        col1, col2 = st.columns([2, 1])

        with col1:
            st.header("üí° Real-Time Coaching Suggestions")
            suggestions, priorities = generate_coaching_suggestions(features)

            # Priority 3 -> red, 2 -> yellow, anything else -> green styling.
            alert_class_by_priority = {3: 'alert-high', 2: 'alert-medium'}
            for suggestion, priority in zip(suggestions, priorities):
                alert_class = alert_class_by_priority.get(priority, 'alert-good')
                st.markdown(f"<div class='coaching-alert {alert_class}'>{suggestion}</div>",
                            unsafe_allow_html=True)

        with col2:
            st.header("üìä Call Metrics")

            st.metric("Speaking Rate", f"{_as_float(row.get('SPEAKING_RATE_WPM', 0)):.0f} WPM")
            st.metric("Agent Talk Ratio", f"{_as_float(row.get('AGENT_TALK_RATIO', 0)):.1%}")
            st.metric("Pause Duration", f"{_as_float(row.get('AVG_PAUSE_DURATION_SEC', 0)):.1f}s")
            st.metric("Word Count", f"{_as_count(row.get('WORD_COUNT', 0))}")
            st.metric("Entities", f"{_as_count(row.get('ENTITY_COUNT', 0))}")

        # Detailed Features (Expandable)
        with st.expander("üîç View All Features"):
            st.dataframe(features.T, use_container_width=True)

    else:
        st.error("Could not load features for this call")

# Footer -- rendered before the auto-refresh pause because st.rerun() ends the
# current script run immediately; anything placed after it would never be
# drawn while auto-refresh is enabled.
st.markdown("---")
st.caption("üéØ Real-Time Agent Coaching Dashboard | Powered by Snowflake ML & Feature Store")

# Auto-refresh logic: wait, then rerun the whole script to pull fresh data.
if st.session_state.auto_refresh:
    time.sleep(5)
    st.rerun()