In [1]:
%pip install -q streamlit==1.46.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m46.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.1/79.1 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
%%writefile app.py
import os
import streamlit as st
import pandas as pd
import numpy as np
from typing import List, Tuple, Optional
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import seaborn as sns
import matplotlib.pyplot as plt
import time
import hashlib
import pickle
from pathlib import Path
import subprocess
import sys
import warnings
warnings.filterwarnings('ignore')

# -------------------- Configuration --------------------
# Page-level settings (browser tab title/icon, wide layout, sidebar open on load)
_page_settings = dict(
    page_title="📧 Email Classification Explorer",
    page_icon="📧",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_page_settings)

# Custom CSS with Royal Green Theme
# Inject the app-wide stylesheet. The sidebar and expander rules additionally
# target Streamlit's `data-testid` hooks: `.css-1d391kg` is an auto-generated
# class name and `.streamlit-expanderHeader` is a legacy class, so both are
# kept only for older releases and may not match the pinned Streamlit version.
st.markdown("""
<style>
    /* Main color palette */
    :root {
        --royal-dark: #003A00;
        --royal-medium: #005500;
        --royal-light: #007000;
        --royal-pale: #228B22;
        --royal-bg: #E8F5E8;
        --royal-accent: #00AA00;
    }

    .main-header {
        font-size: 2.5rem;
        color: var(--royal-dark);
        text-align: center;
        margin-bottom: 2rem;
        text-shadow: 2px 2px 4px rgba(0, 58, 0, 0.2);
        font-weight: 600;
    }

    .metric-card {
        background: linear-gradient(135deg, #005500 0%, #003A00 100%);
        padding: 1.2rem;
        border-radius: 12px;
        color: white;
        text-align: center;
        margin: 0.5rem 0;
        box-shadow: 0 4px 6px rgba(0, 58, 0, 0.3);
        transition: transform 0.2s;
    }

    .metric-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 12px rgba(0, 58, 0, 0.4);
    }

    .stAlert > div {
        background: linear-gradient(90deg, #228B22, #005500);
        color: white;
        border-radius: 8px;
        border: none;
    }

    .email-result {
        border-left: 4px solid var(--royal-medium);
        padding: 1rem;
        margin: 0.5rem 0;
        background: var(--royal-bg);
        border-radius: 0 10px 10px 0;
        box-shadow: 0 2px 4px rgba(0, 58, 0, 0.1);
    }

    .similarity-score {
        background: var(--royal-pale);
        padding: 0.3rem 0.8rem;
        border-radius: 20px;
        font-weight: 600;
        color: white;
        font-size: 0.9rem;
    }

    /* Sidebar styling */
    .css-1d391kg,
    [data-testid="stSidebar"] {
        background-color: var(--royal-bg);
    }

    /* Button styling */
    .stButton > button {
        background: linear-gradient(135deg, #005500 0%, #003A00 100%);
        color: white;
        border: none;
        padding: 0.5rem 1rem;
        font-weight: 500;
        transition: all 0.3s;
    }

    .stButton > button:hover {
        background: linear-gradient(135deg, #007000 0%, #005500 100%);
        box-shadow: 0 4px 8px rgba(0, 58, 0, 0.3);
    }

    /* Success/Warning/Error messages */
    .stSuccess {
        background-color: var(--royal-pale);
        color: white;
        border-left: 4px solid var(--royal-medium);
    }

    .stWarning {
        background-color: #F0E68C;
        color: #5C4033;
        border-left: 4px solid #DAA520;
    }

    .stError {
        background-color: #FFE4E1;
        color: #8B0000;
        border-left: 4px solid #DC143C;
    }

    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        background-color: var(--royal-bg);
        border-radius: 8px;
    }

    .stTabs [data-baseweb="tab"] {
        color: var(--royal-dark);
        font-weight: 500;
    }

    .stTabs [aria-selected="true"] {
        background-color: var(--royal-medium);
        color: white;
        border-radius: 6px;
    }

    /* Expander styling */
    .streamlit-expanderHeader,
    [data-testid="stExpander"] summary {
        background-color: var(--royal-pale);
        color: white;
        border-radius: 8px;
        font-weight: 500;
    }

    /* Input field styling */
    .stTextInput > div > div > input,
    .stTextArea > div > div > textarea {
        border: 2px solid var(--royal-light);
        border-radius: 8px;
    }

    .stTextInput > div > div > input:focus,
    .stTextArea > div > div > textarea:focus {
        border-color: var(--royal-medium);
        box-shadow: 0 0 0 3px rgba(0, 85, 0, 0.1);
    }

    /* Slider styling */
    .stSlider > div > div > div {
        background-color: var(--royal-light);
    }

    .stSlider > div > div > div > div {
        background-color: var(--royal-medium);
    }
</style>
""", unsafe_allow_html=True)

# Royal Green color scheme for Plotly charts
# Green ramp (pale -> near-black) reused by the Plotly figures
ROYAL_COLOR_SCALE = [
    '#E8F5E8',
    '#A8D5A8',
    '#7FBF7F',
    '#228B22',
    '#007000',
    '#005500',
    '#003A00',
    '#002600',
]

# -------------------- Cache Directory --------------------
# Folder for downloaded/cached artefacts; created at import time if absent.
CACHE_DIR = Path("./.cache")
CACHE_DIR.mkdir(exist_ok=True)

def main():
    """Main application function"""

    # -------------------- Auto Download Sample Data --------------------
    def install_and_import_gdown():
        """Return the ``gdown`` module, pip-installing it first when the import fails.

        The install runs ``pip`` through the current interpreter
        (``sys.executable``) while a Streamlit spinner is shown.
        """
        try:
            import gdown
        except ImportError:
            pip_command = [sys.executable, "-m", "pip", "install", "gdown"]
            with st.spinner("📦 Installing gdown package..."):
                subprocess.check_call(pip_command)
            # Second attempt: the package should now be importable.
            import gdown
        return gdown

    @st.cache_data
    def download_sample_data():
        """Load the sample e-mail dataset, downloading it from Google Drive if needed.

        Resolution order:
          1. ``CACHE_DIR / "sample_emails.csv"`` if it exists and yields rows.
          2. Download through ``gdown``.
          3. Download through ``requests`` as a fallback.

        Returns:
            tuple: ``(dataframe, status_message)``. On total failure the
            dataframe is empty and the message starts with "❌".

        Note:
            The function is wrapped in ``st.cache_data``, so whatever it
            returns (including a failure result) is memoised until the cache
            is cleared.
        """
        sample_file_path = CACHE_DIR / "sample_emails.csv"
        file_id = "1N7rk-kfnDFIGMeX0ROVTjKh71gcgx-7R"

        def load_and_validate():
            """Parse the file on disk and fail loudly when no usable rows remain."""
            df = pd.read_csv(sample_file_path)
            df_processed = process_and_validate_dataframe(df)
            if len(df_processed) == 0:
                raise ValueError("No valid data after processing")
            return df_processed

        # Reuse a previously downloaded copy when it is still readable.
        if sample_file_path.exists():
            try:
                df = pd.read_csv(sample_file_path)
                df_processed = process_and_validate_dataframe(df)
                if len(df_processed) > 0:
                    return df_processed, "✅ Loaded cached sample data"
            except Exception:
                # `Exception` (not a bare except) so interpreter/Streamlit
                # control-flow exceptions are not swallowed. A corrupt cache
                # file is discarded and fetched again below.
                sample_file_path.unlink(missing_ok=True)

        # Method 1: Try with gdown
        try:
            gdown = install_and_import_gdown()
            url = f"https://drive.google.com/uc?id={file_id}"

            with st.spinner("📥 Downloading sample data from Google Drive..."):
                gdown.download(url, str(sample_file_path), quiet=True)

            df_processed = load_and_validate()
            return df_processed, f"✅ Downloaded sample data: {len(df_processed)} emails"

        except Exception as gdown_error:
            st.warning(f"⚠️ gdown method failed: {str(gdown_error)}")

            # Method 2: Try with requests as fallback
            try:
                import requests

                download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

                with st.spinner("📥 Trying alternative download method..."):
                    # A timeout keeps a stalled connection from hanging the app.
                    response = requests.get(download_url, timeout=30)
                    response.raise_for_status()

                    # Save the file
                    with open(sample_file_path, 'wb') as f:
                        f.write(response.content)

                df_processed = load_and_validate()
                return df_processed, f"✅ Downloaded sample data (alternative method): {len(df_processed)} emails"

            except Exception as requests_error:
                # Do not leave a partial/unparseable download in the cache dir.
                try:
                    sample_file_path.unlink(missing_ok=True)
                except OSError:
                    pass
                st.error(f"❌ All download methods failed:")
                st.error(f"• gdown error: {str(gdown_error)}")
                st.error(f"• requests error: {str(requests_error)}")
                st.error("Please check your internet connection or the Google Drive file ID.")
                return pd.DataFrame(), "❌ Failed to load sample data"

    def process_and_validate_dataframe(df):
        """Normalise an arbitrary CSV frame into ``id`` / ``email`` / ``label`` columns.

        Column roles are detected by case-insensitive header name. When detection
        fails, a positional fallback is used (2 columns: label, email; 3+
        columns: id, email, label). Rows missing an email or label and rows
        whose email has 5 characters or fewer are dropped; ids are then
        regenerated as 1..n and the index is reset to 0..n-1.

        Progress and debug information is written to the Streamlit sidebar.

        Args:
            df: Raw dataframe as read from the CSV.

        Returns:
            pandas.DataFrame: The cleaned frame, or an empty frame when the
            input is empty, has a single column, or cleaning raised.
        """
        if df.empty:
            return pd.DataFrame()

        # Debug: Show what we actually got
        st.sidebar.markdown("### 🔍 Debug Info")
        st.sidebar.write("**Downloaded columns:**")
        st.sidebar.write(list(df.columns))
        st.sidebar.write("**Shape:**", df.shape)
        if len(df) > 0:
            st.sidebar.write("**First few rows:**")
            st.sidebar.dataframe(df.head(3))

        def find_column(candidates):
            """Return the first column whose lower-cased name is a candidate, else None."""
            wanted = {name.lower() for name in candidates}
            for col in df.columns:
                # str() so non-string headers (e.g. integers) cannot crash the lookup
                if str(col).lower() in wanted:
                    return col
            return None

        # Accepted header names per role; insertion order fixes the mapping order.
        candidates_by_role = {
            'id': ['id', 'ID', 'index', 'Index', 'message_id', 'email_id'],
            'email': ['message', 'Message', 'email', 'Email', 'text', 'Text', 'content', 'Content', 'body', 'Body'],
            'label': ['category', 'Category', 'label', 'Label', 'class', 'Class', 'type', 'Type', 'spam', 'Spam'],
        }
        column_mapping = {}
        for role, candidates in candidates_by_role.items():
            matched = find_column(candidates)
            if matched is not None:
                column_mapping[role] = matched

        # Show mapping results
        st.sidebar.write("**Column mapping found:**")
        st.sidebar.write(column_mapping)

        has_text_and_label = 'email' in column_mapping and 'label' in column_mapping
        generated_ids = range(1, len(df) + 1)
        n_columns = len(df.columns)

        if n_columns == 2:
            if has_text_and_label:
                processed_df = pd.DataFrame({
                    'id': generated_ids,
                    'email': df[column_mapping['email']],
                    'label': df[column_mapping['label']]
                })
                st.sidebar.success("✅ Perfect! Auto-detected Category/Message format")
            else:
                # Fallback: assume first column is label, second is email
                processed_df = pd.DataFrame({
                    'id': generated_ids,
                    'email': df.iloc[:, 1],
                    'label': df.iloc[:, 0]
                })
                st.sidebar.warning("⚠️ Using columns as: Category (col 1) → label, Message (col 2) → email")

        elif n_columns >= 3:
            if has_text_and_label and 'id' in column_mapping:
                processed_df = pd.DataFrame({
                    'id': df[column_mapping['id']],
                    'email': df[column_mapping['email']],
                    'label': df[column_mapping['label']]
                })
                st.sidebar.success("✅ Found all required columns with mapping")
            elif has_text_and_label:
                processed_df = pd.DataFrame({
                    'id': generated_ids,
                    'email': df[column_mapping['email']],
                    'label': df[column_mapping['label']]
                })
                st.sidebar.success("✅ Found email and label columns, generated IDs")
            else:
                # Use first 3 columns as fallback
                processed_df = pd.DataFrame({
                    'id': df.iloc[:, 0],
                    'email': df.iloc[:, 1],
                    'label': df.iloc[:, 2]
                })
                st.sidebar.warning("⚠️ Using first 3 columns as id, email, label")

        else:
            st.error(f"❌ Cannot process file with only {len(df.columns)} column(s)")
            return pd.DataFrame()

        # Clean and validate the data
        try:
            initial_count = len(processed_df)

            # Only email/label matter here: ids are regenerated below, so a
            # missing id must not cost us the row.
            processed_df = processed_df.dropna(subset=['email', 'label'])

            processed_df['email'] = processed_df['email'].astype(str)
            processed_df['label'] = processed_df['label'].astype(str)

            # Keep messages longer than 5 characters; reset the index so row
            # positions and index labels agree after filtering.
            long_enough = processed_df['email'].str.len() > 5
            processed_df = processed_df[long_enough].reset_index(drop=True)

            # Reset IDs to be sequential
            processed_df['id'] = range(1, len(processed_df) + 1)

            # Show cleaning results
            final_count = len(processed_df)
            if final_count < initial_count:
                st.sidebar.info(f"🧹 Cleaned data: {initial_count} → {final_count} rows")

            # Show sample of processed data
            st.sidebar.write("**Processed sample:**")
            st.sidebar.dataframe(processed_df.head(3))

            # Show label distribution
            if len(processed_df) > 0:
                label_counts = processed_df['label'].value_counts()
                st.sidebar.write("**Label distribution:**")
                st.sidebar.write(label_counts.to_dict())

            st.sidebar.success(f"✅ Successfully processed {len(processed_df)} emails")

            return processed_df

        except Exception as e:
            st.error(f"❌ Error processing dataframe: {str(e)}")
            return pd.DataFrame()

    # -------------------- Sidebar Configuration --------------------
    st.sidebar.markdown("## ⚙️ Configuration")

    # Numeric tunables for batching and chunking
    BATCH_SIZE = st.sidebar.number_input(
        "Batch size",
        min_value=1,
        max_value=100,
        value=10,
        help="Number of texts to process at once",
    )
    CHUNK_SIZE = st.sidebar.number_input(
        "Chunk size (tokens)",
        min_value=100,
        max_value=2000,
        value=1000,
        help="Maximum tokens per chunk",
    )
    OVERLAP = st.sidebar.number_input(
        "Chunk overlap (tokens)",
        min_value=0,
        max_value=500,
        value=200,
        help="Token overlap between chunks",
    )

    # Embedding-plot options
    DIM_REDUCTION = st.sidebar.selectbox(
        "Dimensionality Reduction",
        ["PCA", "t-SNE", "UMAP"],
        help="Method for visualizing embeddings",
    )
    VISUALIZATION_3D = st.sidebar.checkbox(
        "3D Visualization",
        value=False,
        help="Enable 3D visualization for embeddings",
    )

    st.sidebar.markdown("---")
    st.sidebar.markdown("## 📊 Model Info")
    # Display name only; the label itself says the model is simulated.
    EMBED_MODEL = "text-embedding-ada-002 (Simulated)"
    st.sidebar.info(f"**Model:** {EMBED_MODEL}")

    # -------------------- Utility Functions --------------------
    @st.cache_data
    def chunk_text(text: str, max_tokens: int, overlap: int) -> List[str]:
        """Split ``text`` into overlapping chunks of whitespace-separated words.

        Args:
            text: Text to split.
            max_tokens: Maximum number of words per chunk (must be >= 1).
            overlap: Number of words shared between consecutive chunks.

        Returns:
            List of chunk strings; empty when ``text`` contains no words.

        Raises:
            ValueError: If ``max_tokens`` is smaller than 1.
        """
        if max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")

        words = text.split()

        # The window must always move forward. Without this floor an overlap
        # that is >= max_tokens gives a zero or negative stride and the loop
        # below never terminates.
        step = max(1, max_tokens - overlap)

        chunks = []
        start = 0
        while start < len(words):
            end = min(len(words), start + max_tokens)
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                # This chunk already reaches the last word.
                break
            start += step

        return chunks

    def simple_hash_embedding(text: str, dimensions: int = 384) -> np.ndarray:
        """Create a deterministic, hash-based pseudo-embedding for demonstration.

        The vector blends two parts and is then L2-normalised:
          * 70%: the 16 MD5 digest bytes of the whole text, scaled to [0, 1]
            and repeated cyclically up to ``dimensions``;
          * 30%: a bag-of-words signal for the first 50 lower-cased words, each
            hashed into a bucket and weighted by ``1 / (position + 1)``.

        Word buckets are derived from MD5 rather than the built-in ``hash()``,
        because ``hash()`` of a string is salted per interpreter process and
        would give different vectors for the same text after a restart.

        Args:
            text: Input text.
            dimensions: Length of the returned vector.

        Returns:
            1-D float array of length ``dimensions``.
        """
        # Whole-text fingerprint: digest bytes scaled to [0, 1], tiled to size.
        digest = hashlib.md5(text.encode()).digest()
        fingerprint = np.frombuffer(digest, dtype=np.uint8) / 255.0
        embedding = np.resize(fingerprint, dimensions)

        # Position-weighted bag of words over the first 50 words.
        word_features = np.zeros(dimensions)
        for position, word in enumerate(text.lower().split()[:50]):
            word_digest = hashlib.md5(word.encode()).digest()
            bucket = int.from_bytes(word_digest[:8], "big") % dimensions
            word_features[bucket] += 1.0 / (position + 1)

        # Combine hash and word features
        final_embedding = 0.7 * embedding + 0.3 * word_features

        # Normalize (skipped for an all-zero vector)
        norm = np.linalg.norm(final_embedding)
        if norm > 0:
            final_embedding = final_embedding / norm

        return final_embedding

    def perform_dimensionality_reduction(embeddings_array, method="PCA", n_components=2):
        """Project embeddings to ``n_components`` dimensions for plotting.

        Args:
            embeddings_array: 2-D array-like passed to the reducer's
                ``fit_transform``.
            method: One of ``"PCA"``, ``"t-SNE"`` or ``"UMAP"``. UMAP falls back
                to PCA (with a Streamlit warning) when the ``umap`` package
                cannot be imported.
            n_components: Target dimensionality.

        Returns:
            The array returned by the reducer's ``fit_transform``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        # Imported unconditionally at function scope. A `from ... import PCA`
        # inside only the "PCA" branch makes `PCA` a function local that is
        # unbound in the UMAP fallback branch, which then raises
        # UnboundLocalError.
        from sklearn.decomposition import PCA

        if method == "PCA":
            reducer = PCA(n_components=n_components, random_state=42)
        elif method == "t-SNE":
            from sklearn.manifold import TSNE
            # Perplexity is capped below the sample count.
            perplexity = min(30, len(embeddings_array) - 1)
            reducer = TSNE(n_components=n_components, random_state=42, perplexity=perplexity)
        elif method == "UMAP":
            try:
                import umap
                reducer = umap.UMAP(n_components=n_components, random_state=42)
            except ImportError:
                st.warning("⚠️ UMAP not installed, falling back to PCA")
                reducer = PCA(n_components=n_components, random_state=42)
        else:
            raise ValueError(f"Unknown dimensionality reduction method: {method!r}")

        return reducer.fit_transform(embeddings_array)

    def create_smart_classifier(df_used, embeddings):
        """Train a random-forest classifier on embeddings and the frame's labels.

        Args:
            df_used: Dataframe with a ``label`` column, row-aligned with
                ``embeddings``.
            embeddings: Feature matrix with one row per dataframe row.

        Returns:
            tuple: ``(classifier, label_encoder, X_test, y_test)`` where the
            test split is the held-out 20%.

        Raises:
            ValueError: If the frame is empty or its length differs from the
                number of embeddings.
        """
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import LabelEncoder

        # Ensure we're using the same dataframe that was used to create embeddings
        if len(df_used) != len(embeddings):
            raise ValueError(f"Mismatch: DataFrame has {len(df_used)} rows but {len(embeddings)} embeddings")
        if len(df_used) == 0:
            raise ValueError("Cannot train a classifier on an empty dataset")

        # Encode labels
        label_encoder = LabelEncoder()
        encoded_labels = label_encoder.fit_transform(df_used['label'])

        # Stratification needs >= 2 rows per class and at least one slot per
        # class in both splits; otherwise train_test_split raises. In that case
        # fall back to a plain shuffled split instead of crashing.
        test_fraction = 0.2
        class_counts = np.bincount(encoded_labels)
        n_test = int(np.ceil(test_fraction * len(encoded_labels)))
        n_train = len(encoded_labels) - n_test
        can_stratify = (
            class_counts.min() >= 2
            and min(n_train, n_test) >= len(class_counts)
        )

        # Split data for training
        X_train, X_test, y_train, y_test = train_test_split(
            embeddings,
            encoded_labels,
            test_size=test_fraction,
            random_state=42,
            stratify=encoded_labels if can_stratify else None,
        )

        # Train classifier
        classifier = RandomForestClassifier(n_estimators=100, random_state=42)
        classifier.fit(X_train, y_train)

        return classifier, label_encoder, X_test, y_test

    def predict_email_class(email_text, classifier, label_encoder, get_email_vector_func):
        """Predict the label of a single e-mail and report per-class probabilities.

        Args:
            email_text: Raw e-mail text.
            classifier: Fitted estimator exposing ``predict`` and
                ``predict_proba`` (and normally ``classes_``), trained on
                integer codes produced by ``label_encoder``.
            label_encoder: Fitted encoder exposing ``classes_`` and
                ``inverse_transform``.
            get_email_vector_func: Callable mapping the text to a 1-D vector.

        Returns:
            tuple: ``(predicted_label, {class_name: probability})``. Classes the
            classifier never saw during training are reported with 0.0.
        """
        # Get embedding for the email as a single-row matrix
        email_vector = get_email_vector_func(email_text).reshape(1, -1)

        # Predict
        prediction_encoded = classifier.predict(email_vector)[0]
        prediction_proba = classifier.predict_proba(email_vector)[0]

        # Decode prediction
        prediction_label = label_encoder.inverse_transform([prediction_encoded])[0]

        # predict_proba columns follow classifier.classes_, which can be a
        # subset of the encoder's classes. Map by code instead of by position;
        # use positions only when the codes are unavailable.
        try:
            trained_codes = [int(code) for code in classifier.classes_]
        except (AttributeError, TypeError, ValueError):
            trained_codes = list(range(len(prediction_proba)))
        proba_by_code = dict(zip(trained_codes, prediction_proba))

        class_probabilities = {
            class_name: proba_by_code.get(code, 0.0)
            for code, class_name in enumerate(label_encoder.classes_)
        }

        return prediction_label, class_probabilities

    def is_spam_classifier(email_text, df_used, embeddings, get_email_vector_func):
        """Classify one e-mail as ``'spam'`` or ``'ham'`` with a freshly trained model.

        Every label equal to "spam" (case-insensitive, surrounding whitespace
        ignored) is treated as spam; everything else as ham. The model is a
        random forest trained on the embeddings concatenated with 12
        hand-crafted text features. A rule-based override flips a ``'ham'``
        prediction to ``'spam'`` when at least 5 spam keywords occur in the
        text.

        Args:
            email_text: Text to classify.
            df_used: Dataframe with ``email`` and ``label`` columns, row-aligned
                with ``embeddings``.
            embeddings: 2-D embedding matrix, one row per dataframe row.
            get_email_vector_func: Callable mapping text to a 1-D embedding.

        Returns:
            tuple: ``(label, confidence, held_out_accuracy)``.

        Raises:
            ValueError: If the frame is empty or its length differs from the
                number of embeddings.
        """
        # Ensure we're using the same dataframe that was used to create embeddings
        if len(df_used) != len(embeddings):
            raise ValueError(f"Mismatch: DataFrame has {len(df_used)} rows but {len(embeddings)} embeddings")
        if len(df_used) == 0:
            raise ValueError("Cannot train a spam classifier on an empty dataset")

        import re
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.model_selection import train_test_split
        from sklearn.preprocessing import LabelEncoder

        # Create binary labels (spam vs non-spam)
        binary_labels = [
            'spam' if str(label).strip().lower() == 'spam' else 'ham'
            for label in df_used['label']
        ]

        # Patterns compiled once instead of on every feature extraction.
        url_regex = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
        email_regex = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')
        phone_regex = re.compile(r'[\+\(]?[1-9][0-9 .\-\(\)]{8,}[0-9]')

        spam_keywords = (
            'win', 'winner', 'winning', 'won', 'prize', 'congratulations',
            'click here', 'click now', 'urgent', 'limited time', 'act now',
            'free', 'money', 'cash', 'discount', 'offer', 'deal', 'guarantee',
            'viagra', 'pills', 'medication', 'weight loss', 'lose weight',
            'million', 'thousand', 'hundred', 'dollars', 'pounds', 'euros',
            'dear customer', 'dear friend', 'dear sir', 'dear madam',
            'unsubscribe', 'remove', 'stop receiving', 'opt out',
            'call now', 'call today', 'don\'t hesitate', 'apply now',
            'special offer', 'limited offer', 'exclusive offer', 'best price',
            'no cost', 'no fees', 'no charges', 'no obligation',
            'order now', 'buy now', 'purchase now', 'shop now',
            'claim', 'redeem', 'collect', 'receive',
            'guaranteed', '100%', 'risk free', 'satisfaction',
            'increase', 'boost', 'enhance', 'improve',
            'credit', 'loan', 'debt', 'mortgage'
        )
        # Position of the keyword score inside the feature vector below.
        keyword_feature_index = 7

        def extract_spam_features(text):
            """Return the 12 hand-crafted features; the order is part of the model input."""
            text_lower = text.lower()
            words = text.split()
            n_chars = len(text)

            caps_ratio = sum(1 for c in text if c.isupper()) / n_chars if n_chars > 0 else 0
            numeric_ratio = sum(1 for c in text if c.isdigit()) / n_chars if n_chars > 0 else 0
            avg_word_length = sum(len(word) for word in words) / len(words) if words else 0
            # Substring match (not whole-word) on the lower-cased text.
            keyword_hits = sum(1 for keyword in spam_keywords if keyword in text_lower)

            return np.array([
                caps_ratio,                        # 1. capital letters ratio
                text.count('!'),                   # 2. exclamation marks
                text.count('?'),                   # 3. question marks
                text.count('$'),                   # 4. dollar signs
                len(url_regex.findall(text)),      # 5. URLs
                len(email_regex.findall(text)),    # 6. e-mail addresses
                len(phone_regex.findall(text)),    # 7. phone numbers
                keyword_hits,                      # 8. spam keyword score
                n_chars,                           # 9. text length
                len(words),                        # 10. word count
                avg_word_length,                   # 11. average word length
                numeric_ratio,                     # 12. numeric characters ratio
            ])

        # Hand-crafted features for the corpus and for the input e-mail
        additional_features = np.array([extract_spam_features(email) for email in df_used['email']])
        input_additional_features = extract_spam_features(email_text)

        # Combine embeddings with additional features
        combined_features = np.hstack([embeddings, additional_features])
        input_combined_features = np.hstack([
            get_email_vector_func(email_text),
            input_additional_features
        ]).reshape(1, -1)

        # Encode binary labels
        label_encoder = LabelEncoder()
        encoded_labels = label_encoder.fit_transform(binary_labels)

        # Stratify only when feasible (>= 2 rows per class and room for every
        # class in both splits); otherwise train_test_split would raise.
        test_fraction = 0.2
        class_counts = np.bincount(encoded_labels)
        n_test = int(np.ceil(test_fraction * len(encoded_labels)))
        n_train = len(encoded_labels) - n_test
        can_stratify = (
            class_counts.min() >= 2
            and min(n_train, n_test) >= len(class_counts)
        )

        X_train, X_test, y_train, y_test = train_test_split(
            combined_features,
            encoded_labels,
            test_size=test_fraction,
            random_state=42,
            stratify=encoded_labels if can_stratify else None,
        )

        classifier = RandomForestClassifier(
            n_estimators=150,
            max_depth=20,
            min_samples_split=5,
            min_samples_leaf=2,
            random_state=42,
            class_weight='balanced'  # Handle imbalanced classes
        )
        classifier.fit(X_train, y_train)

        # Predict for input email
        prediction_encoded = classifier.predict(input_combined_features)[0]
        prediction_proba = classifier.predict_proba(input_combined_features)[0]

        prediction_label = label_encoder.inverse_transform([prediction_encoded])[0]
        confidence = max(prediction_proba)

        # Rule-based override for obvious spam, reusing the features already
        # computed for the input instead of extracting them a second time.
        spam_score = input_additional_features[keyword_feature_index]
        if spam_score >= 5 and prediction_label == 'ham':
            prediction_label = 'spam'
            confidence = min(0.9, confidence + 0.2)

        return prediction_label, confidence, classifier.score(X_test, y_test)

    def plot_metrics_dashboard(y_true, y_pred, label_encoder):
        """Render per-class precision/recall/F1 charts, a table and overall metrics.

        Args:
            y_true: True encoded labels.
            y_pred: Predicted encoded labels.
            label_encoder: Fitted encoder; ``classes_`` supplies display names
                and its length defines which codes are scored.

        Returns:
            pandas.DataFrame with columns Label, Precision, Recall, F1-Score,
            Support (one row per class).
        """
        class_names = label_encoder.classes_
        panel_background = '#E8F5E8'

        # Per-class scores for codes 0..n_classes-1
        precision, recall, f1_score, support = precision_recall_fscore_support(
            y_true, y_pred, labels=np.arange(len(class_names)), average=None
        )

        metrics_df = pd.DataFrame({
            'Label': class_names,
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1_score,
            'Support': support
        })

        bars_column, pie_column = st.columns(2)

        with bars_column:
            # Grouped bars: one trace per metric, in a fixed order and colour
            fig_metrics = go.Figure()
            bar_series = (
                ('Precision', precision, '#228B22'),
                ('Recall', recall, '#005500'),
                ('F1-Score', f1_score, '#003A00'),
            )
            for series_name, series_values, series_color in bar_series:
                fig_metrics.add_trace(go.Bar(
                    name=series_name,
                    x=class_names,
                    y=series_values,
                    marker_color=series_color
                ))

            fig_metrics.update_layout(
                title='Classification Metrics by Label',
                xaxis_title='Labels',
                yaxis_title='Score',
                barmode='group',
                height=400,
                plot_bgcolor=panel_background,
                paper_bgcolor=panel_background
            )
            st.plotly_chart(fig_metrics, use_container_width=True)

        with pie_column:
            # Share of samples per class
            fig_support = px.pie(
                values=support,
                names=class_names,
                title='Support Distribution (Number of Samples)',
                color_discrete_sequence=ROYAL_COLOR_SCALE[2:]
            )
            fig_support.update_layout(
                height=400,
                plot_bgcolor=panel_background,
                paper_bgcolor=panel_background
            )
            st.plotly_chart(fig_support, use_container_width=True)

        # Metrics table
        st.markdown("#### 📊 Detailed Metrics Table")
        st.dataframe(metrics_df.round(4), use_container_width=True)

        # Support-weighted averages, shown side by side
        overall_scores = (
            ("Overall Precision", np.average(precision, weights=support)),
            ("Overall Recall", np.average(recall, weights=support)),
            ("Overall F1-Score", np.average(f1_score, weights=support)),
        )
        for slot, (caption, score) in zip(st.columns(3), overall_scores):
            with slot:
                st.metric(caption, f"{score:.4f}")

        return metrics_df

    def plot_confusion_matrix(y_true, y_pred, label_encoder):
        """Render the confusion matrix as a Plotly heatmap in the Streamlit page.

        Args:
            y_true: True encoded labels.
            y_pred: Predicted encoded labels.
            label_encoder: Fitted encoder whose ``classes_`` name the axes.
        """
        # Get label names
        labels = label_encoder.classes_

        # Pin the matrix to one row/column per encoder class. Without `labels=`
        # a class absent from both y_true and y_pred is dropped and the matrix
        # no longer lines up with the axis names. The same code range is used
        # by plot_metrics_dashboard.
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))

        # Create heatmap with royal green color scheme
        fig = go.Figure(data=go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale=[[0, '#E8F5E8'], [0.5, '#228B22'], [1, '#003A00']],
            text=cm,
            texttemplate='%{text}',
            textfont={"size": 12},
            hovertemplate='True: %{y}<br>Predicted: %{x}<br>Count: %{z}<extra></extra>'
        ))

        fig.update_layout(
            title='Confusion Matrix',
            xaxis_title='Predicted Label',
            yaxis_title='True Label',
            xaxis={'side': 'bottom'},
            width=600,
            height=500,
            plot_bgcolor='#E8F5E8',
            paper_bgcolor='#E8F5E8'
        )

        st.plotly_chart(fig, use_container_width=True)

    @st.cache_data
    def get_email_vector(email: str, chunk_size: int, overlap: int) -> np.ndarray:
        """Embed an e-mail, averaging chunk embeddings when it exceeds ``chunk_size`` words.

        Args:
            email: E-mail text.
            chunk_size: Word count above which the text is chunked.
            overlap: Word overlap passed to ``chunk_text``.

        Returns:
            The embedding of the whole text, or the element-wise mean of the
            chunk embeddings for long texts.
        """
        # Short enough: embed in one go.
        if len(email.split()) <= chunk_size:
            return simple_hash_embedding(email)

        # Long text: embed each overlapping chunk and average.
        chunk_vectors = [
            simple_hash_embedding(piece)
            for piece in chunk_text(email, chunk_size, overlap)
        ]
        return np.mean(chunk_vectors, axis=0)

    # -------------------- Main App --------------------
    st.markdown('<h1 class="main-header">👑 Email Classification & Embedding Explorer</h1>', unsafe_allow_html=True)

    # First script run of a session: make sure the bootstrap flag exists.
    if 'auto_data_loaded' not in st.session_state:
        st.session_state['auto_data_loaded'] = False

    # Bootstrap the sample dataset once per session.
    if not st.session_state['auto_data_loaded']:
        try:
            sample_df, message = download_sample_data()
            sample_has_rows = len(sample_df) > 0
            if sample_has_rows:
                # Only publish the frame when it actually contains data.
                st.session_state['df'] = sample_df
            st.session_state['sample_data_message'] = (
                message if sample_has_rows
                else "❌ No data loaded - please upload a CSV file"
            )
        except Exception as e:
            st.session_state['sample_data_message'] = f"❌ Error loading sample data: {str(e)}"
        # Success or failure, do not try again on every rerun.
        st.session_state['auto_data_loaded'] = True

    # Status banner: the emoji in the stored message picks the alert type.
    if 'sample_data_message' in st.session_state:
        status_message = st.session_state['sample_data_message']
        if "✅" in status_message:
            show_status = st.success
        elif "⚠️" in status_message:
            show_status = st.warning
        else:
            show_status = st.info
        show_status(status_message)

    # File upload section
    col1, col2, col3 = st.columns([2, 1, 1])

    with col1:
        st.markdown("### 📁 Upload Additional Data")
        uploaded_file = st.file_uploader(
            "Upload your CSV file to add more data",
            type=['csv'],
            help="CSV should contain columns: id, email, label. This will be added to existing data."
        )

    with col2:
        st.markdown("### 🔄 Data Management")
        if st.button("🔄 Reload Sample Data"):
            # st.rerun() below restarts the script immediately, so anything rendered
            # here would be wiped before the user can read it. The outcome is stored
            # in 'sample_data_message' instead, which the status banner at the top of
            # the page displays on the next run.
            st.session_state['auto_data_loaded'] = False
            try:
                sample_df, _ = download_sample_data()
                if len(sample_df) > 0:
                    st.session_state['df'] = sample_df
                    st.session_state['auto_data_loaded'] = True
                    # Embeddings were built from the previous dataframe; drop them so
                    # the analysis sections never run against outdated vectors.
                    for stale_key in ['embeddings', 'email_data']:
                        if stale_key in st.session_state:
                            del st.session_state[stale_key]
                    st.session_state['sample_data_message'] = "✅ Sample data reloaded!"
                else:
                    # 'auto_data_loaded' stays False, so the next run retries the auto-load
                    st.session_state['sample_data_message'] = "❌ Failed to reload sample data"
            except Exception as e:
                st.session_state['sample_data_message'] = f"❌ Reload error: {str(e)}"
            st.rerun()

    with col3:
        st.markdown("### 🗑️ Reset")
        if st.button("🗑️ Clear All Data"):
            # Drop the dataframe and everything derived from it. Resetting
            # 'auto_data_loaded' makes the next run load the sample data again.
            for key in ['df', 'embeddings', 'email_data']:
                if key in st.session_state:
                    del st.session_state[key]
            st.session_state['auto_data_loaded'] = False
            # Only flashes briefly: the rerun below replaces it with the auto-load banner
            st.success("✅ Data cleared!")
            st.rerun()

    # Handle file upload (append to existing data).
    # Streamlit re-executes the whole script on every widget interaction and the
    # uploader keeps returning the same file until the user removes it. Without a
    # guard the CSV would be appended again (and the embeddings deleted) on every
    # button click, so each upload is fingerprinted and merged only once.
    if uploaded_file is None:
        # Uploader is empty: forget the fingerprint so the same file can be added again later
        if 'processed_upload_digest' in st.session_state:
            del st.session_state['processed_upload_digest']
    else:
        upload_digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()

        if st.session_state.get('processed_upload_digest') == upload_digest:
            st.info(
                f"📎 '{uploaded_file.name}' was already added. "
                "Remove it from the uploader and upload it again to add it once more."
            )
        else:
            try:
                uploaded_file.seek(0)  # make sure parsing starts at the beginning of the buffer
                new_df = pd.read_csv(uploaded_file)

                # Process and validate uploaded file
                processed_new_df = process_and_validate_dataframe(new_df)

                if len(processed_new_df) == 0:
                    st.error("❌ Uploaded file contains no valid data after processing")
                else:
                    # Get existing data
                    existing_df = st.session_state.get('df', pd.DataFrame())

                    if len(existing_df) > 0:
                        # Append new data to existing; shift the new IDs past the
                        # current maximum to avoid conflicts
                        max_existing_id = existing_df['id'].max()
                        processed_new_df['id'] = processed_new_df['id'] + max_existing_id

                        # Combine dataframes
                        combined_df = pd.concat([existing_df, processed_new_df], ignore_index=True)
                        st.session_state['df'] = combined_df
                        st.success(f"✅ Added {len(processed_new_df)} new emails! Total: {len(combined_df)} emails")
                    else:
                        st.session_state['df'] = processed_new_df
                        st.success(f"✅ Loaded {len(processed_new_df)} emails!")

                    # Clear embeddings since data changed
                    if 'embeddings' in st.session_state:
                        del st.session_state['embeddings']
                    if 'email_data' in st.session_state:
                        del st.session_state['email_data']

                    # Remember this upload only after it was merged successfully;
                    # a failed upload is retried (and its error shown) on the next run
                    st.session_state['processed_upload_digest'] = upload_digest

            except Exception as e:
                st.error(f"❌ Error loading file: {str(e)}")
                st.error("Please check that your file is a valid CSV")

    # Get current dataframe (None when nothing has been loaded)
    df = st.session_state['df'] if 'df' in st.session_state else None

    # Display data overview
    if df is not None and len(df) > 0:
        st.markdown("### 📊 Data Overview")

        # Words per email (whitespace split). Computed once and shared by the
        # average/max metric cards and the length histogram further down.
        word_counts = df["email"].str.split().str.len()

        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.markdown(f'<div class="metric-card"><h3>{len(df)}</h3><p>Total Emails</p></div>', unsafe_allow_html=True)

        with col2:
            st.markdown(f'<div class="metric-card"><h3>{df["label"].nunique()}</h3><p>Unique Labels</p></div>', unsafe_allow_html=True)

        with col3:
            avg_length = word_counts.mean()
            st.markdown(f'<div class="metric-card"><h3>{avg_length:.0f}</h3><p>Avg Words</p></div>', unsafe_allow_html=True)

        with col4:
            max_length = word_counts.max()
            st.markdown(f'<div class="metric-card"><h3>{max_length}</h3><p>Max Words</p></div>', unsafe_allow_html=True)

        # Label distribution: the same counts drive the bar chart and the pie chart
        st.markdown("### 🏷️ Label Distribution")
        label_counts = df["label"].value_counts()

        col1, col2 = st.columns([1, 1])

        with col1:
            fig_bar = px.bar(
                x=label_counts.index,
                y=label_counts.values,
                title="Email Count by Label",
                color=label_counts.values,
                color_continuous_scale=ROYAL_COLOR_SCALE
            )
            fig_bar.update_layout(
                showlegend=False,
                xaxis_title="Label",
                yaxis_title="Count",
                plot_bgcolor='#E8F5E8',
                paper_bgcolor='#E8F5E8'
            )
            st.plotly_chart(fig_bar, use_container_width=True)

        with col2:
            fig_pie = px.pie(
                values=label_counts.values,
                names=label_counts.index,
                title="Label Distribution",
                color_discrete_sequence=ROYAL_COLOR_SCALE[2:]
            )
            fig_pie.update_layout(
                plot_bgcolor='#E8F5E8',
                paper_bgcolor='#E8F5E8'
            )
            st.plotly_chart(fig_pie, use_container_width=True)

        # Email length analysis
        st.markdown("### 📏 Email Length Analysis")

        fig_hist = px.histogram(
            x=word_counts,
            nbins=20,
            title="Distribution of Email Lengths (Words)",
            color_discrete_sequence=["#005500"]
        )
        fig_hist.update_layout(
            xaxis_title="Number of Words",
            yaxis_title="Frequency",
            plot_bgcolor='#E8F5E8',
            paper_bgcolor='#E8F5E8'
        )
        st.plotly_chart(fig_hist, use_container_width=True)

        # Build embeddings
        st.markdown("### 🧠 Build Embeddings")

        # Bind the sampling settings up front so the button handler below can read
        # them regardless of which branch of the large-dataset UI was rendered.
        use_sampling = False
        sample_size = len(df)

        # Add sampling option for large datasets
        if len(df) > 1000:
            st.warning(f"⚠️ Large dataset detected ({len(df)} emails). Consider sampling for faster processing.")

            col_sample1, col_sample2 = st.columns(2)
            with col_sample1:
                use_sampling = st.checkbox("Enable sampling for demo", value=False)
            with col_sample2:
                if use_sampling:
                    sample_size = st.slider("Sample size", min_value=100, max_value=len(df), value=min(len(df), 5000))

        col1, col2 = st.columns([1, 1])

        with col1:
            if st.button("🚀 Build Embeddings", type="primary"):
                # Prepare data
                if use_sampling and len(df) > sample_size:
                    # Stratified sampling: every label contributes an equal share, capped
                    # at the label's own size and never less than one email. The groups
                    # are sampled explicitly instead of via groupby().apply(), whose
                    # treatment of the grouping column differs between pandas versions,
                    # so the 'label' column is always kept.
                    per_label = max(1, sample_size // max(1, df['label'].nunique()))
                    label_samples = [
                        group.sample(min(len(group), per_label))
                        for _, group in df.groupby('label')
                    ]
                    # Fall back to the full data if no label group exists (e.g. all labels missing)
                    df_sample = pd.concat(label_samples).reset_index(drop=True) if label_samples else df.copy()
                    st.info(f"🎯 Using stratified sample: {len(df_sample)} emails from {len(df)} total")
                else:
                    df_sample = df.copy()

                with st.spinner("Building embeddings..."):
                    # Create progress containers
                    progress_bar = st.progress(0)
                    progress_text = st.empty()
                    status_text = st.empty()

                    embeddings = []
                    total_emails = len(df_sample)
                    batch_size = BATCH_SIZE

                    status_text.info(f"🔄 Processing {total_emails} emails in batches of {batch_size}...")

                    # Process in batches so the progress display is refreshed once per
                    # batch rather than once per email
                    for batch_start in range(0, total_emails, batch_size):
                        batch_emails = df_sample["email"].iloc[batch_start:batch_start + batch_size]
                        embeddings.extend(
                            get_email_vector(email_text, CHUNK_SIZE, OVERLAP)
                            for email_text in batch_emails
                        )

                        # Update progress
                        progress_value = len(embeddings) / total_emails
                        progress_bar.progress(progress_value)
                        progress_text.text(f"Processed {len(embeddings)} of {total_emails} emails ({progress_value:.1%})")

                    # Store results; 'email_data' is the exact frame the vectors were
                    # built from, so row i of 'embeddings' belongs to row i of 'email_data'
                    st.session_state['embeddings'] = np.array(embeddings)
                    st.session_state['email_data'] = df_sample.copy()

                    # Clean up UI
                    progress_bar.empty()
                    progress_text.empty()
                    status_text.empty()

                    st.success(f"✅ Built embeddings for {len(embeddings)} emails!")

                    # Show embedding stats
                    embedding_shape = st.session_state['embeddings'].shape
                    st.info(f"📊 Embedding shape: {embedding_shape[0]} emails × {embedding_shape[1]} dimensions")

        with col2:
            if 'embeddings' in st.session_state:
                st.info(f"📊 Embeddings ready: {st.session_state['embeddings'].shape}")

        # Similarity search and email classification
        if 'embeddings' in st.session_state:
            st.markdown("### 🔍 Email Analysis & Prediction")

            # Create tabs for different analysis types
            tab1, tab2, tab3 = st.tabs(["🔍 Similarity Search", "🎯 Email Classification", "🚨 Spam Detection"])

            with tab1:
                st.markdown("#### Find Similar Emails")
                query_text = st.text_area(
                    "Enter email text to find similar emails:",
                    placeholder="Type your email content here...",
                    height=100,
                    key="similarity_search"
                )

                col1, col2 = st.columns([1, 3])

                with col1:
                    top_k = st.slider("Number of results", min_value=1, max_value=10, value=5, key="similarity_k")

                with col2:
                    if st.button("🔍 Find Similar Emails") and query_text:
                        # stdlib; imported locally so this section does not depend on the file header
                        from html import escape

                        # Embed the query the same way the stored emails were embedded,
                        # then rank every stored email by cosine similarity to it
                        query_vector = get_email_vector(query_text, CHUNK_SIZE, OVERLAP).reshape(1, -1)
                        similarities = cosine_similarity(query_vector, st.session_state['embeddings'])[0]

                        # Get top k most similar (descending similarity)
                        top_indices = np.argsort(similarities)[::-1][:top_k]

                        st.markdown("#### 📋 Search Results")

                        for rank, idx in enumerate(top_indices, start=1):
                            email_data = st.session_state['email_data'].iloc[idx]
                            similarity_score = similarities[idx]

                            # The fields come from user-supplied CSV data and are rendered
                            # with unsafe_allow_html=True, so they are HTML-escaped to stop
                            # markup or scripts inside an email from being injected into the page
                            safe_id = escape(str(email_data['id']))
                            safe_label = escape(str(email_data['label']))
                            safe_body = escape(str(email_data['email']))

                            st.markdown(f"""
                            <div class="email-result">
                                <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;">
                                    <div>
                                        <strong>#{rank} - ID: {safe_id}</strong> |
                                        <strong>Label:</strong> {safe_label}
                                    </div>
                                    <span class="similarity-score">Similarity: {similarity_score:.4f}</span>
                                </div>
                                <div style="background: white; padding: 10px; border-radius: 5px; font-size: 14px;">
                                    {safe_body}
                                </div>
                            </div>
                            """, unsafe_allow_html=True)

            with tab2:
                st.markdown("#### Multi-Class Email Classification")
                classify_text = st.text_area(
                    "Enter email text for classification:",
                    placeholder="Type the email content you want to classify...",
                    height=100,
                    key="classification_text"
                )

                classify_clicked = st.button("🎯 Classify Email")
                if classify_clicked and classify_text:
                    try:
                        # Fit the classifier on the same dataframe/vectors produced by "Build Embeddings"
                        with st.spinner("Training classifier..."):
                            classifier, label_encoder, X_test, y_test = create_smart_classifier(
                                st.session_state['email_data'],
                                st.session_state['embeddings']
                            )

                        # One-argument embedder handed to predict_email_class
                        def embed_email(email):
                            return get_email_vector(email, CHUNK_SIZE, OVERLAP)

                        prediction_label, class_probabilities = predict_email_class(
                            classify_text, classifier, label_encoder, embed_email
                        )

                        # Left: verdict and confidence. Right: per-class probabilities.
                        result_col, chart_col = st.columns([1, 2])

                        with result_col:
                            st.markdown("#### 🎯 Prediction Result")
                            confidence = max(class_probabilities.values())

                            # Banner style and metric caption depend on the confidence tier
                            if confidence > 0.7:
                                show_banner, metric_caption = st.success, "Confidence"
                            elif confidence > 0.5:
                                show_banner, metric_caption = st.warning, "Confidence"
                            else:
                                show_banner, metric_caption = st.error, "Confidence (Low)"

                            show_banner(f"**Predicted Class:** {prediction_label}")
                            st.metric(metric_caption, f"{confidence:.2%}")

                        with chart_col:
                            st.markdown("#### 📊 Class Probabilities")

                            # One row per class, most likely class first
                            prob_records = [
                                {'Class': class_name, 'Probability': prob}
                                for class_name, prob in class_probabilities.items()
                            ]
                            prob_df = pd.DataFrame(prob_records).sort_values('Probability', ascending=False)

                            # Bar chart of the probabilities
                            fig_prob = px.bar(
                                prob_df,
                                x='Class',
                                y='Probability',
                                title='Classification Probabilities',
                                color='Probability',
                                color_continuous_scale=ROYAL_COLOR_SCALE
                            )
                            fig_prob.update_layout(
                                height=300,
                                plot_bgcolor='#E8F5E8',
                                paper_bgcolor='#E8F5E8'
                            )
                            st.plotly_chart(fig_prob, use_container_width=True)

                            # Same numbers as a table
                            st.dataframe(prob_df, use_container_width=True, hide_index=True)

                        # Accuracy of the freshly trained model on its held-out split
                        test_accuracy = classifier.score(X_test, y_test)
                        st.info(f"📈 Model Test Accuracy: {test_accuracy:.2%}")

                    except Exception as e:
                        st.error(f"❌ Classification error: {str(e)}")

            with tab3:
                st.markdown("#### 🚨 Spam/Ham Detection")
                spam_text = st.text_area(
                    "Enter email text for spam detection:",
                    placeholder="Type the email content to check if it's spam or legitimate...",
                    height=100,
                    key="spam_detection_text"
                )

                if st.button("🚨 Check for Spam") and spam_text:
                    try:
                        # stdlib; imported locally so this section does not depend on the file header
                        import re

                        with st.spinner("Analyzing email for spam..."):
                            get_email_vector_lambda = lambda email: get_email_vector(email, CHUNK_SIZE, OVERLAP)
                            prediction, confidence, model_accuracy = is_spam_classifier(
                                spam_text,
                                st.session_state['email_data'],
                                st.session_state['embeddings'],
                                get_email_vector_lambda
                            )

                        # Display spam detection results
                        col1, col2, col3 = st.columns(3)

                        with col1:
                            if prediction == 'spam':
                                st.error("🚨 **SPAM DETECTED**")
                                st.markdown("⚠️ This email appears to be spam")
                            else:
                                st.success("✅ **LEGITIMATE EMAIL**")
                                st.markdown("📧 This email appears to be legitimate")

                        with col2:
                            st.metric("Detection Confidence", f"{confidence:.2%}")

                            # Confidence interpretation
                            if confidence > 0.8:
                                st.success("High confidence")
                            elif confidence > 0.6:
                                st.warning("Medium confidence")
                            else:
                                st.error("Low confidence")

                        with col3:
                            st.metric("Model Accuracy", f"{model_accuracy:.2%}")
                            st.info("Trained on your dataset")

                        # Additional analysis: simple keyword/pattern heuristics shown
                        # next to the model verdict (they do not influence the prediction)
                        st.markdown("#### 🔍 Email Analysis Details")

                        email_lower = spam_text.lower()

                        # Common spam indicators
                        spam_words = ['win', 'prize', 'congratulations', 'click here', 'urgent', 'limited time',
                                     'free', 'money', 'cash', 'discount', 'offer', 'deal', 'guarantee']

                        # Match whole words/phrases (with an optional plural "s") instead of raw
                        # substrings, so "showing" no longer reports "win" and "ideal" no longer
                        # reports "deal"
                        spam_indicators = [
                            word for word in spam_words
                            if re.search(r'\b' + re.escape(word) + r's?\b', email_lower)
                        ]

                        # Check for suspicious patterns
                        uppercase_count = sum(1 for c in spam_text if c.isupper())
                        if uppercase_count > len(spam_text) * 0.3:
                            spam_indicators.append("Excessive capital letters")

                        if spam_text.count('!') > 3:
                            spam_indicators.append("Multiple exclamation marks")

                        if spam_indicators:
                            st.warning("⚠️ Potential spam indicators found:")
                            for indicator in spam_indicators:
                                st.write(f"• {indicator}")
                        else:
                            st.success("✅ No obvious spam indicators detected")

                    except Exception as e:
                        st.error(f"❌ Spam detection error: {str(e)}")

            # Visualization: two independent buttons, each projecting the stored
            # embeddings with perform_dimensionality_reduction (method taken from
            # DIM_REDUCTION) and drawing a scatter plot coloured by label.
            st.markdown("### 📈 Embedding Visualization")

            col1, col2 = st.columns([1, 1])

            with col1:
                if st.button("🎨 Generate 2D Visualization"):
                    with st.spinner("Generating 2D visualization..."):
                        embeddings = st.session_state['embeddings']

                        reduced_embeddings = perform_dimensionality_reduction(
                            embeddings, method=DIM_REDUCTION, n_components=2
                        )

                        # Create visualization dataframe: projected coordinates plus the
                        # label/id/preview columns of the dataframe the embeddings were built from
                        viz_df = pd.DataFrame({
                            'x': reduced_embeddings[:, 0],
                            'y': reduced_embeddings[:, 1],
                            'label': st.session_state['email_data']['label'],
                            'id': st.session_state['email_data']['id'],
                            # First 100 characters; "..." is appended even when the email is shorter
                            'email_preview': st.session_state['email_data']['email'].str[:100] + "..."
                        })

                        fig_scatter = px.scatter(
                            viz_df,
                            x='x',
                            y='y',
                            color='label',
                            hover_data=['id', 'email_preview'],
                            title=f"Email Embeddings 2D Visualization ({DIM_REDUCTION})",
                            width=800,
                            height=600,
                            color_discrete_sequence=ROYAL_COLOR_SCALE[2:]
                        )

                        fig_scatter.update_traces(marker=dict(size=8, opacity=0.7))
                        fig_scatter.update_layout(
                            xaxis_title=f"{DIM_REDUCTION} Component 1",
                            yaxis_title=f"{DIM_REDUCTION} Component 2",
                            plot_bgcolor='#E8F5E8',
                            paper_bgcolor='#E8F5E8'
                        )

                        st.plotly_chart(fig_scatter, use_container_width=True)

            with col2:
                if st.button("🌐 Generate 3D Visualization"):
                    with st.spinner("Generating 3D visualization..."):
                        embeddings = st.session_state['embeddings']

                        reduced_embeddings = perform_dimensionality_reduction(
                            embeddings, method=DIM_REDUCTION, n_components=3
                        )

                        # Create 3D visualization dataframe (same columns as the 2D case plus 'z')
                        viz_df_3d = pd.DataFrame({
                            'x': reduced_embeddings[:, 0],
                            'y': reduced_embeddings[:, 1],
                            'z': reduced_embeddings[:, 2],
                            'label': st.session_state['email_data']['label'],
                            'id': st.session_state['email_data']['id'],
                            'email_preview': st.session_state['email_data']['email'].str[:100] + "..."
                        })

                        # Create 3D scatter plot
                        fig_3d = px.scatter_3d(
                            viz_df_3d,
                            x='x',
                            y='y',
                            z='z',
                            color='label',
                            hover_data=['id', 'email_preview'],
                            title=f"Email Embeddings 3D Visualization ({DIM_REDUCTION})",
                            width=800,
                            height=700,
                            color_discrete_sequence=ROYAL_COLOR_SCALE[2:]
                        )

                        fig_3d.update_traces(marker=dict(size=5, opacity=0.8))
                        fig_3d.update_layout(
                            scene=dict(
                                xaxis_title=f"{DIM_REDUCTION} Component 1",
                                yaxis_title=f"{DIM_REDUCTION} Component 2",
                                zaxis_title=f"{DIM_REDUCTION} Component 3",
                                # Initial viewpoint of the 3D scene
                                camera=dict(
                                    eye=dict(x=1.5, y=1.5, z=1.5)
                                ),
                                bgcolor='#E8F5E8'
                            ),
                            paper_bgcolor='#E8F5E8'
                        )

                        st.plotly_chart(fig_3d, use_container_width=True)

            # Classification Metrics and Analysis
            st.markdown("### 📊 Classification Metrics Analysis")

            report_requested = st.button("📈 Generate Classification Report")
            if report_requested:
                with st.spinner("Generating classification metrics..."):
                    # Fit an evaluation classifier on the dataframe/vectors stored by "Build Embeddings"
                    classifier, label_encoder, X_test, y_test = create_smart_classifier(
                        st.session_state['email_data'], st.session_state['embeddings']
                    )

                    # Held-out targets and the model's predictions for them
                    y_true = y_test
                    y_pred = classifier.predict(X_test)

                    st.markdown("#### 🎯 Classification Performance")
                    n_emails = len(st.session_state['email_data'])
                    st.info(f"📝 Results based on trained Random Forest classifier with 80/20 train-test split on {n_emails} emails.")

                    # Metrics dashboard
                    plot_metrics_dashboard(y_true, y_pred, label_encoder)

                    # Confusion matrix
                    st.markdown("#### 🔄 Confusion Matrix")
                    plot_confusion_matrix(y_true, y_pred, label_encoder)

                    # Per-class report, computed on the decoded (original) label values
                    st.markdown("#### 📋 Detailed Classification Report")
                    decoded_true = label_encoder.inverse_transform(y_true)
                    decoded_pred = label_encoder.inverse_transform(y_pred)

                    report_dict = classification_report(decoded_true, decoded_pred, output_dict=True)
                    report_df = pd.DataFrame(report_dict).transpose()
                    st.dataframe(report_df.round(4), use_container_width=True)

                    # Overall accuracy on the held-out split
                    test_accuracy = classifier.score(X_test, y_test)
                    st.success(f"🎯 Overall Test Accuracy: {test_accuracy:.2%}")

            # Vector Database Visualization
            st.markdown("### 🗃️ Vector Database Analysis")

            analyze_clicked = st.button("🔍 Analyze Vector Database")
            if analyze_clicked and len(st.session_state['embeddings']) < 2:
                # With a single vector there are no pairs: the upper triangle would be
                # empty and np.max/np.min below would raise on it
                st.warning("⚠️ At least 2 emails are needed to analyze pairwise similarities")
            elif analyze_clicked:
                with st.spinner("Analyzing vector database..."):
                    embeddings = st.session_state['embeddings']

                    # Calculate similarity matrix (every email against every other email)
                    similarity_matrix = cosine_similarity(embeddings)

                    # Unique pairs only: the strict upper triangle leaves out the diagonal
                    # (self-similarity) and the mirrored lower half
                    upper_triangle = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]

                    col1, col2 = st.columns(2)

                    with col1:
                        # Similarity heatmap
                        fig_sim = px.imshow(
                            similarity_matrix[:20, :20],  # Show first 20x20 for readability
                            title="Cosine Similarity Matrix (First 20 emails)",
                            color_continuous_scale=[[0, '#E8F5E8'], [0.5, '#228B22'], [1, '#003A00']],
                            aspect="auto"
                        )
                        fig_sim.update_layout(
                            height=400,
                            plot_bgcolor='#E8F5E8',
                            paper_bgcolor='#E8F5E8'
                        )
                        st.plotly_chart(fig_sim, use_container_width=True)

                    with col2:
                        # Distribution of similarities
                        fig_dist = px.histogram(
                            x=upper_triangle,
                            nbins=50,
                            title="Distribution of Cosine Similarities",
                            labels={'x': 'Cosine Similarity', 'y': 'Frequency'},
                            color_discrete_sequence=['#005500']
                        )
                        fig_dist.update_layout(
                            height=400,
                            plot_bgcolor='#E8F5E8',
                            paper_bgcolor='#E8F5E8'
                        )
                        st.plotly_chart(fig_dist, use_container_width=True)

                    # Vector statistics: one metric card per summary statistic of the pairwise similarities
                    st.markdown("#### 📊 Vector Database Statistics")

                    similarity_stats = [
                        ("Average Similarity", np.mean),
                        ("Similarity Std Dev", np.std),
                        ("Max Similarity", np.max),
                        ("Min Similarity", np.min),
                    ]
                    for stat_col, (stat_label, stat_fn) in zip(st.columns(4), similarity_stats):
                        with stat_col:
                            st.metric(stat_label, f"{stat_fn(upper_triangle):.4f}")

                    # Clustering analysis
                    st.markdown("#### 🎯 Clustering Analysis")

                    try:
                        from sklearn.cluster import KMeans

                        # Perform k-means clustering with one cluster per label, capped at 5
                        # and never fewer than 1 (nunique() is 0 when every label is missing)
                        n_clusters = max(1, min(5, st.session_state['email_data']['label'].nunique()))
                        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
                        cluster_labels = kmeans.fit_predict(embeddings)

                        # Create cluster visualization (always projected with PCA)
                        reduced_for_clustering = perform_dimensionality_reduction(
                            embeddings, method="PCA", n_components=2
                        )

                        cluster_df = pd.DataFrame({
                            'x': reduced_for_clustering[:, 0],
                            'y': reduced_for_clustering[:, 1],
                            'cluster': cluster_labels,
                            'true_label': st.session_state['email_data']['label'],
                            'id': st.session_state['email_data']['id']
                        })

                        # Colour encodes the k-means cluster, marker symbol the true label
                        fig_cluster = px.scatter(
                            cluster_df,
                            x='x',
                            y='y',
                            color='cluster',
                            symbol='true_label',
                            hover_data=['id'],
                            title="K-Means Clustering of Email Embeddings",
                            color_continuous_scale=ROYAL_COLOR_SCALE
                        )

                        fig_cluster.update_traces(marker=dict(size=8, opacity=0.7))
                        fig_cluster.update_layout(
                            plot_bgcolor='#E8F5E8',
                            paper_bgcolor='#E8F5E8'
                        )
                        st.plotly_chart(fig_cluster, use_container_width=True)

                    except ImportError:
                        st.warning("⚠️ Scikit-learn clustering not available")

        # Data preview
        if st.checkbox("👀 Show Raw Data"):
            st.markdown("### 📋 Raw Data Preview")
            st.dataframe(df, use_container_width=True)

    else:
        # Show information when no data is loaded
        st.warning("⚠️ No data loaded. Please upload a CSV file to get started.")

        # Show expected format examples
        st.markdown("### 📝 Supported CSV Formats")

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("**Format 1: Category/Message**")
            st.code("""
Category,Message
spam,"You won $1000! Click here to claim"
ham,"Hello team, meeting at 2PM tomorrow"
spam,"URGENT! Your account will be suspended"
ham,"Thanks for your email, I'll reply soon"
            """)

        with col2:
            st.markdown("**Format 2: Standard 3-column**")
            st.code("""
id,email,label
1,"You won $1000! Click here",spam
2,"Hello team, meeting at 2PM",ham
3,"URGENT! Account suspended",spam
4,"Thanks for your email",ham
            """)

    # Footer
    st.markdown("---")
    st.markdown(
        """
        <div style='text-align: center; color: #003A00; padding: 20px;'>
            👑 Email Classification & Embedding Explorer |
            Built with Streamlit & Python |
            🚀 Powered by Machine Learning
        </div>
        """,
        unsafe_allow_html=True
    )

# Script entry point: call main() only when this file is executed as the main
# script (the notebook launches it with `streamlit run app.py`), not on import.
if __name__ == "__main__":
    main()

Writing app.py


In [3]:
# Query loca.lt's "mytunnelpassword" endpoint from this machine. In the recorded
# output the response is an IP address; presumably this is the value the
# localtunnel landing page asks visitors to enter -- verify against loca.lt docs.
!curl https://loca.lt/mytunnelpassword

34.148.227.158

In [None]:
# Start the Streamlit app in the background, then expose it through localtunnel.
# - `--server.port 8501` pins Streamlit to the same port the tunnel forwards, so
#   the two cannot silently disagree if Streamlit would otherwise pick another port.
# - `--yes` answers npx's "Ok to proceed?" install prompt, which would otherwise
#   stall a non-interactive Restart & Run All.
# - `localtunnel@2.0.2` pins the version that the recorded run installed.
# This cell blocks for as long as the tunnel is open; interrupt the kernel to stop it.
!streamlit run app.py --server.port 8501 & npx --yes localtunnel@2.0.2 --port 8501

[1G[0K⠙
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K⠧[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.148.227.158:8501[0m
[0m
[1G[0K⠇[1G[0K⠏[1G[0K⠋[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K[1G[0JNeed to install the following packages:
localtunnel@2.0.2
Ok to proceed? (y) [20Gy

[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K⠧[1G[0K⠇[1G[0K⠏[1G[0K⠋[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K⠧[1G[0K⠇[1G[0K⠏[1G[0K⠋[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K⠧[1G[0K⠇[1G[0K⠏[1G[0K⠋[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K⠧[1G[0Kyour url is: https://petite-mice-chew.loca.lt
Websocket message size limit exceeded