<a href="https://colab.research.google.com/github/Karthikreddy1010/Automated-Scientific-Data-Paper-Linkage-with-Contextual-Summarization/blob/main/Topic_modelling_databasesystems.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install bertopic
!pip install umap-learn
!pip install hdbscan
!pip install sentence-transformers
!pip install plotly
!pip install wordcloud
!pip install gensim
!pip install keybert
!pip install tqdm
!pip install matplotlib
!pip install seaborn

In [None]:
# For CTM support
!pip install contextualized-topic-models

In [None]:
import logging
import pandas as pd
import numpy as np
import os
from datetime import datetime, timezone
import torch
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import silhouette_score
import pickle
import re
import json
from collections import Counter
import gensim
from gensim import corpora
from gensim.models.coherencemodel import CoherenceModel
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from wordcloud import WordCloud

# Silence every warning category for the rest of the session to keep notebook
# output readable. NOTE(review): this also hides library deprecation warnings.
warnings.simplefilter('ignore')

# Confirm the environment is usable and record the key library versions.
print(
    "All imports completed successfully!",
    f"PyTorch version: {torch.__version__}",
    f"Pandas version: {pd.__version__}",
    sep="\n",
)



class DomainConfig:
    """Central settings for the domain topic-modelling pipeline.

    Every setting is a class attribute, so it can be read from the class or
    from an instance. Instantiating the class also makes sure OUTPUT_DIR
    exists on disk.
    """

    # ---- Input / output locations ----
    PROCESSED_TEXT_CSV = "updated.csv"
    OUTPUT_DIR = "domain_modeling_results"

    # ---- Sentence-embedding model identifier ----
    EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"

    # ---- UMAP dimensionality reduction (tuned on an earlier pipeline run) ----
    UMAP_PARAMS = dict(
        n_neighbors=8,        # trade-off between local and global structure
        n_components=3,       # previously 5; fewer dimensions clustered better
        min_dist=0.05,        # previously 0.01
        metric='cosine',
        random_state=42,      # fixed seed
        low_memory=False,
    )

    # ---- HDBSCAN clustering ----
    HDBSCAN_PARAMS = dict(
        min_cluster_size=8,               # previously 6
        min_samples=3,                    # previously 2; steadier clusters
        cluster_selection_epsilon=0.02,
        metric='euclidean',
        cluster_selection_method='eom',   # preferred over 'leaf'
        prediction_data=True,
    )

    # ---- BERTopic model options ----
    BERTOPIC_SETTINGS = dict(
        top_n_words=12,
        n_gram_range=(1, 2),      # unigrams + bigrams (trigrams were dropped)
        min_topic_size=8,         # previously 6
        calculate_probabilities=True,
        verbose=False,
        nr_topics=20,             # explicit target instead of 'auto'
    )

    # ---- Text processing limits ----
    MIN_DOC_LENGTH = 50     # characters; shorter documents are discarded
    MAX_DOC_LENGTH = 2000   # characters; longer documents are truncated
    BATCH_SIZE = 32         # batch size used when encoding embeddings

    def __init__(self):
        # Ensure the results directory is present before anything writes to it.
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

# Module-wide settings instance; creating it also ensures OUTPUT_DIR exists.
config = DomainConfig()
print(" Configuration initialized")



class DataProcessor:
    """Load the processed-text CSV and clean each document for topic modelling.

    Length limits (MIN_DOC_LENGTH / MAX_DOC_LENGTH) and the input path
    (PROCESSED_TEXT_CSV) come from a settings object. By default that is the
    module-level ``config``; pass ``settings`` to use a different one.
    """

    # Boilerplate / metadata patterns, compiled once and applied in order.
    # The bare-word alternation is wrapped in \b so it only removes whole words
    # (without it, 'unaccepted' or 'resubmitted' were cut down to fragments).
    _METADATA_PATTERNS = tuple(re.compile(p) for p in (
        r'http\S+|www\S+|https\S+',
        r'doi:\s*\S+',
        r'\S*@\S*\s?',
        r'copyright\s+\d{4}',
        r'all rights reserved',
        r'\b(?:received|accepted|submitted)\b',
        r'creative commons',
        r'peer review',
    ))

    def __init__(self, settings=None):
        """Create a processor.

        Args:
            settings: optional object exposing MIN_DOC_LENGTH, MAX_DOC_LENGTH
                and PROCESSED_TEXT_CSV. When omitted, the module-level
                ``config`` is looked up each time it is needed.
        """
        self._settings = settings
        # Scientific stop words - generic paper vocabulary; domain terms are kept.
        self.scientific_stop_words = {
            'paper', 'study', 'research', 'result', 'method', 'approach',
            'show', 'demonstrate', 'present', 'investigate', 'analyze',
            'discuss', 'conclude', 'suggest', 'indicate', 'figure', 'table',
            'author', 'journal', 'publication', 'reference', 'citation',
            'section', 'abstract', 'introduction', 'background', 'conclusion'
        }

    @property
    def _cfg(self):
        """Settings in effect: the injected object, else module-level ``config``."""
        return config if self._settings is None else self._settings

    def clean_text(self, text):
        """Clean one document while preserving scientific content.

        Lower-cases the text, truncates it to MAX_DOC_LENGTH characters,
        strips URLs/DOIs/e-mail addresses and publishing boilerplate, and
        drops tokens that are too short (<=2 chars), too long (>=25 chars),
        purely numeric, or scientific stop words.

        Returns:
            The cleaned text, or "" when the input is not a string or is
            shorter than MIN_DOC_LENGTH characters before or after cleaning.
            A non-empty result is therefore always >= MIN_DOC_LENGTH long.
        """
        if not isinstance(text, str):
            return ""

        cfg = self._cfg
        text = text.lower().strip()
        if len(text) < cfg.MIN_DOC_LENGTH:
            return ""
        text = text[:cfg.MAX_DOC_LENGTH]

        for pattern in self._METADATA_PATTERNS:
            text = pattern.sub(' ', text)

        # str.split() already collapses any whitespace run.
        cleaned_text = ' '.join(
            word for word in text.split()
            if 2 < len(word) < 25
            and word not in self.scientific_stop_words
            and not word.isdigit()
        )
        return cleaned_text if len(cleaned_text) >= cfg.MIN_DOC_LENGTH else ""

    def load_and_process_data(self):
        """Read PROCESSED_TEXT_CSV, clean its 'processed_text' column and drop empties.

        Returns:
            DataFrame with an added 'cleaned_text' column, containing only rows
            whose cleaned text is non-empty.

        Raises:
            Whatever pandas raises while reading/processing (e.g. a missing
            file or a missing 'processed_text' column); the error is printed
            first and then re-raised.
        """
        try:
            cfg = self._cfg
            print(" Loading and processing data...")

            df = pd.read_csv(cfg.PROCESSED_TEXT_CSV)
            print(f" Loaded {len(df)} documents")

            df['cleaned_text'] = df['processed_text'].apply(self.clean_text)

            # clean_text yields "" or a string of at least MIN_DOC_LENGTH
            # characters, so keep every non-empty result. A strict
            # `> MIN_DOC_LENGTH` test would drop documents exactly at the limit.
            initial_count = len(df)
            df = df[df['cleaned_text'].str.len() > 0].copy()
            final_count = len(df)

            if initial_count != final_count:
                print(f" Removed {initial_count - final_count} documents after cleaning")

            text_lengths = df['cleaned_text'].str.len()
            word_counts = df['cleaned_text'].str.split().str.len()

            print(f" Final dataset: {final_count} documents")
            print(f" Average length: {text_lengths.mean():.1f} chars, {word_counts.mean():.1f} words")

            return df

        except Exception as e:
            print(f" Data loading failed: {e}")
            raise



class EnhancedDomainClassifier:
    """Assign each topic a primary research domain by keyword matching.

    A topic's top terms are scored against every domain's keyword list:
    position-weighted exact matches among the first eight terms, plus
    whole-word (or whole-phrase) presence anywhere in the topic text. Near
    ties between related domains are broken with a short list of strong
    indicator words, and a confidence is derived from the score distribution.
    """

    # Highly diagnostic words per domain. They give a score bonus when found
    # among a topic's first four terms and break ties between related domains.
    _STRONG_INDICATORS = {
        'biology': ('gene', 'cell', 'dna', 'protein', 'genome', 'evolution'),
        'medicine': ('patient', 'clinical', 'treatment', 'diagnosis', 'therapy', 'hospital'),
        'chemistry': ('molecule', 'reaction', 'compound', 'synthesis', 'catalyst'),
        'environmental_science': ('climate', 'ecosystem', 'pollution', 'conservation', 'habitat'),
        'psychology': ('behavior', 'cognitive', 'mental', 'brain', 'memory'),
        'physics': ('quantum', 'particle', 'energy', 'mechanics', 'relativity'),
        'materials_science': ('material', 'composite', 'ceramic', 'polymer', 'alloy'),
    }

    def __init__(self):
        # Keyword lists per domain (single words and multi-word phrases).
        # NOTE(review): several entries look tuned to one corpus rather than to
        # the domain (e.g. 'jet'/'proton' under biology, 'londrina'/'wash'
        # under medicine, 'university' under two domains); confirm intent.
        self.domain_keywords = {
            'biology': [
                'cell', 'gene', 'protein', 'dna', 'genetic', 'molecular', 'organism',
                'evolution', 'genome', 'species', 'ecological', 'biodiversity',
                'microbial', 'enzyme', 'metabolism', 'phylogenetic', 'transcription',
                'rna', 'chromosome', 'mitochondria', 'apoptosis', 'sequence',
                'mutation', 'expression', 'cellular', 'developmental', 'plant',
                'animal', 'bacterial', 'viral', 'evolutionary', 'population',
                'biological', 'physiology', 'genomic', 'proteomic', 'transcriptomic',
                'microbiome', 'neuroscience', 'zoology', 'botany', 'ecology',
                'soybean', 'fruit', 'circrnas', 'drosophila', 'nutrient', 'invasive',
                'breast', 'strain', 'speech', 'brain', 'neural', 'jet', 'physics',
                'collision', 'proton', 'oxygen', 'sex', 'temperature', 'thermal'
            ],
            'medicine': [
                'patient', 'clinical', 'treatment', 'disease', 'medical', 'therapy',
                'health', 'drug', 'vaccine', 'diagnosis', 'symptom', 'hospital',
                'pharmaceutical', 'epidemiology', 'pathology', 'oncology', 'immunology',
                'surgery', 'prognosis', 'biomarker', 'clinical trial', 'pharmacology',
                'therapeutic', 'dosage', 'recovery', 'mortality', 'morbidity', 'cancer',
                'tumor', 'infection', 'inflammatory', 'neurological', 'cardiology',
                'pediatric', 'geriatric', 'psychiatry', 'radiology', 'anesthesia',
                'public health', 'virology', 'bacteriology', 'dose', 'imaging',
                'muscle', 'care', 'inflammation', 'mental', 'virus', 'sars', 'vector',
                'mouse', 'dam', 'like', 'wash', 'londrina'
            ],
            'chemistry': [
                'molecule', 'reaction', 'compound', 'chemical', 'synthesis',
                'catalyst', 'polymer', 'organic', 'inorganic', 'spectroscopy',
                'chromatography', 'crystallography', 'stoichiometry', 'kinetics',
                'nmr', 'mass spectrometry', 'electrochemistry', 'photochemistry',
                'reagent', 'solvent', 'yield', 'purification', 'characterization',
                'crystal', 'bond', 'structure', 'atomic', 'molecular', 'analytical',
                'physical chemistry', 'quantum chemistry', 'medicinal chemistry',
                'biochemistry', 'polymer chemistry', 'materials chemistry',
                'irrigation', 'management', 'soil', 'water'
            ],
            'environmental_science': [
                'environmental', 'climate', 'ecosystem', 'sustainability', 'pollution',
                'conservation', 'biodiversity', 'ecological', 'environmental impact',
                'climate change', 'environmental management', 'sustainable development',
                'habitat', 'wildlife', 'conservation', 'environmental policy', 'earth',
                'geological', 'ocean', 'atmospheric', 'marine', 'terrestrial',
                'agricultural', 'forestry', 'water resources', 'air quality', 'soil science',
                'environmental engineering', 'conservation biology', 'environmental health',
                'ice', 'university', 'usa', 'fish', 'population', 'ecology',
                'geological survey', 'santa', 'survey', 'rupture'
            ],
            'computer_science': [
                'algorithm', 'software', 'programming', 'machine learning',
                'neural network', 'data', 'system', 'computational', 'artificial intelligence',
                'database', 'network', 'optimization', 'cybersecurity', 'blockchain',
                'deep learning', 'computer vision', 'natural language processing',
                'computation', 'modeling', 'simulation', 'data analysis', 'big data',
                'cloud computing', 'internet of things', 'robotics', 'automation'
            ],
            'physics': [
                'quantum', 'particle', 'energy', 'field', 'mechanics', 'astrophysics',
                'relativity', 'nuclear', 'optics', 'thermodynamics', 'electromagnetic',
                'condensed matter', 'cosmology', 'entanglement', 'superconductivity',
                'atomic', 'molecular', 'theoretical', 'experimental', 'quantum mechanics',
                'statistical mechanics', 'fluid dynamics', 'plasma physics', 'optics',
                'astronomy', 'cosmology', 'particle physics', 'solid state physics'
            ],
            'engineering': [
                'design', 'system', 'manufacturing', 'structural', 'electrical',
                'mechanical', 'control', 'sensor', 'robotics', 'automation',
                'aerospace', 'civil', 'materials', 'nanotechnology', 'mechatronics',
                'biomedical', 'chemical engineering', 'environmental engineering'
            ],
            'materials_science': [
                'material', 'composite', 'polymer', 'ceramic', 'metal', 'alloy',
                'nanomaterial', 'crystal', 'structure', 'properties', 'synthesis',
                'fabrication', 'characterization', 'mechanical properties', 'thermal properties',
                'electronic properties', 'optical properties', 'material design',
                'heat', 'strength', 'stimulus', 'group'
            ],
            'psychology': [
                'behavior', 'cognitive', 'psychological', 'mental', 'personality',
                'emotion', 'memory', 'neural', 'brain', 'perception', 'learning',
                'cognition', 'developmental', 'social psychology', 'clinical psychology',
                'behavioral', 'psychiatric', 'neuroscience', 'cognitive science',
                'speech', 'elife', 'neuroscience', 'university'
            ]
        }

        # Groups of related domains; a close call between two members of the
        # same group is decided by strong-indicator counts.
        self.domain_relationships = {
            'ecology': ['biology', 'environmental_science'],
            'neuroscience': ['biology', 'psychology', 'medicine'],
            'biochemistry': ['biology', 'chemistry'],
            'bioinformatics': ['biology', 'computer_science'],
            'materials_science': ['chemistry', 'engineering', 'physics'],
            'environmental_health': ['environmental_science', 'medicine']
        }

        # De-duplicated (order-preserving) keyword lists plus set views for
        # O(1) membership tests. Some lists repeat entries ('conservation',
        # 'optics', 'cosmology', 'neuroscience'); scoring the raw lists counted
        # such keywords twice.
        self._unique_keywords = {
            domain: list(dict.fromkeys(keywords))
            for domain, keywords in self.domain_keywords.items()
        }
        self._keyword_sets = {
            domain: frozenset(keywords)
            for domain, keywords in self._unique_keywords.items()
        }

    def _enhanced_domain_scoring(self, topic_words, topic_text):
        """Score every domain against one topic.

        Args:
            topic_words: ranked topic terms, highest weight first; may contain
                multi-word n-grams such as 'gene expression'.
            topic_text: the same terms joined with spaces and lower-cased.

        Returns:
            dict mapping each domain to a score >= 0.1. The 0.1 floor keeps
            later ratio and confidence computations from dividing by zero.
        """
        # Whole-token view of the text: a single-word keyword must equal a full
        # token (otherwise 'ice' matched inside 'device', 'dam' inside
        # 'damage'). Splitting also exposes the parts of n-gram topic terms.
        text_tokens = set(topic_text.split())
        # Space-padded so multi-word keywords only match on word boundaries.
        padded_text = f" {topic_text} "
        top_words = topic_words[:8]
        lead_words = topic_words[:4]

        domain_scores = {}
        for domain, keywords in self._unique_keywords.items():
            keyword_set = self._keyword_sets[domain]
            score = 0.1  # base score, see docstring

            # Position-based scoring: earlier (higher-ranked) words weigh more.
            matches = [(word, i) for i, word in enumerate(top_words) if word in keyword_set]
            for _, i in matches:
                position_weight = max(0, (8 - i) / 8) * 2.0
                score += 1.0 + position_weight
            matched_words = {word for word, _ in matches}

            # Contextual presence scoring for keywords not already credited.
            for keyword in keywords:
                if keyword in matched_words:
                    continue
                if ' ' in keyword:  # multi-word phrase
                    if f" {keyword} " in padded_text:
                        score += 0.8
                elif keyword in text_tokens:
                    score += 0.3

            # Multi-match bonus
            if len(matches) >= 3:
                score *= 1.3
            elif len(matches) >= 2:
                score *= 1.15

            # Strong domain indicator bonus (first four terms only)
            strong_indicators = self._get_strong_domain_indicators(domain)
            if any(word in strong_indicators for word in lead_words):
                score *= 1.2

            domain_scores[domain] = score

        return domain_scores

    def _get_strong_domain_indicators(self, domain):
        """Return the strong indicator words for ``domain`` (empty list if none)."""
        return list(self._STRONG_INDICATORS.get(domain, ()))

    def _resolve_ambiguous_domains(self, domain_scores, topic_words):
        """Pick the winning domain, breaking close calls between related domains.

        Returns the top-scoring domain unless it leads the runner-up by at most
        50% and both belong to one relationship group. In that case, the domain
        with more strong indicators among the first four topic words wins;
        a tie keeps the top scorer. Returns 'unknown' if no domain scores > 0.
        """
        domain_scores = {k: v for k, v in domain_scores.items() if v > 0}

        if not domain_scores:
            return 'unknown'

        sorted_domains = sorted(domain_scores.items(), key=lambda x: x[1], reverse=True)

        if len(sorted_domains) < 2:
            return sorted_domains[0][0]

        best_domain, best_score = sorted_domains[0]
        second_domain, second_score = sorted_domains[1]

        # Clear winner (score difference > 50%)
        if best_score > second_score * 1.5:
            return best_domain

        for relationship, related_domains in self.domain_relationships.items():
            if best_domain in related_domains and second_domain in related_domains:
                best_indicators = sum(1 for word in topic_words[:4] if word in self._get_strong_domain_indicators(best_domain))
                second_indicators = sum(1 for word in topic_words[:4] if word in self._get_strong_domain_indicators(second_domain))

                if best_indicators > second_indicators:
                    return best_domain
                elif second_indicators > best_indicators:
                    return second_domain

        return best_domain

    def classify_topic_domains(self, topic_model, topics):
        """Classify every non-outlier topic into a domain.

        Args:
            topic_model: fitted model whose ``get_topic(topic_id)`` returns a
                list of (word, weight) pairs, or a falsy value if unknown.
            topics: per-document topic ids; -1 (outliers) is ignored.

        Returns:
            dict {topic_id: {'primary_domain', 'confidence', 'all_domain_scores',
            'topic_keywords'}} with JSON-serialisable values. Topics that
            cannot be classified, or that raise during classification, are
            reported and skipped.
        """
        print(" Performing domain classification...")

        topic_domain_mapping = {}
        unique_topics = set(topics) - {-1}

        for topic_id in unique_topics:
            try:
                topic_words_data = topic_model.get_topic(topic_id)
                if not topic_words_data:
                    print(f"    No words for topic {topic_id}")
                    continue

                all_topic_words = [word for word, _ in topic_words_data[:12]]
                if not all_topic_words:
                    print(f"    Empty topic words for topic {topic_id}")
                    continue

                topic_text = ' '.join(all_topic_words).lower()

                domain_scores = self._enhanced_domain_scoring(all_topic_words, topic_text)

                if not domain_scores:
                    print(f"    No domain scores for topic {topic_id}")
                    continue

                best_domain = self._resolve_ambiguous_domains(domain_scores, all_topic_words)

                if best_domain == 'unknown':
                    print(f"    Could not determine domain for topic {topic_id}")
                    continue

                # Confidence = share of the total score held by the chosen domain.
                total_score = sum(domain_scores.values())
                confidence = domain_scores[best_domain] / max(total_score, 1)

                # Boost confidence when the top score clearly dominates.
                sorted_scores = sorted(domain_scores.values(), reverse=True)
                if len(sorted_scores) > 1:
                    score_ratio = sorted_scores[0] / sorted_scores[1]
                    if score_ratio > 2.0:
                        confidence = min(1.0, confidence * 1.3)
                    elif score_ratio > 1.5:
                        confidence = min(1.0, confidence * 1.15)

                # Ensure JSON serializable types
                topic_domain_mapping[int(topic_id)] = {
                    'primary_domain': best_domain,
                    'confidence': float(confidence),
                    'all_domain_scores': {k: float(v) for k, v in domain_scores.items()},
                    'topic_keywords': all_topic_words[:8]
                }

                confidence_level = "HIGH" if confidence > 0.7 else "MEDIUM" if confidence > 0.4 else "LOW"
                top_keywords = ', '.join(all_topic_words[:4])
                print(f"   Topic {topic_id:2d} → {best_domain:20s} (conf: {confidence:.3f} [{confidence_level}]) - {top_keywords}")

            except Exception as e:
                # Best effort: one bad topic must not abort the whole classification.
                print(f"    Error classifying topic {topic_id}: {str(e)}")
                continue

        self._print_domain_analysis(topic_domain_mapping)
        return topic_domain_mapping

    def _print_domain_analysis(self, domain_mapping):
        """Print per-domain topic counts, shares and average confidence."""
        if not domain_mapping:
            print(" No topics to analyze")
            return

        domain_counts = Counter()
        confidence_sum = Counter()

        for mapping in domain_mapping.values():
            domain = mapping['primary_domain']
            domain_counts[domain] += 1
            confidence_sum[domain] += mapping['confidence']

        print(f"\n DOMAIN DISTRIBUTION:")
        print("=" * 50)

        total_topics = len(domain_mapping)
        for domain, count in domain_counts.most_common():
            percentage = (count / total_topics) * 100
            avg_confidence = confidence_sum[domain] / count
            confidence_level = "HIGH" if avg_confidence > 0.7 else "MEDIUM" if avg_confidence > 0.5 else "LOW"
            print(f"   {domain:20s}: {count:2d} topics ({percentage:5.1f}%) [avg conf: {avg_confidence:.3f} - {confidence_level}]")



class TopicModeler:
    """Embed documents, fit a BERTopic model and report the resulting topics.

    After fit_model() the instance holds the fitted ``topic_model``, the
    document ``embeddings``, per-document ``topics`` and ``probabilities``.
    """

    def __init__(self):
        self.topic_model = None
        self.embeddings = None
        self.topics = None
        self.probabilities = None

    @staticmethod
    def _topic_stats(topics):
        """Return (number of non-outlier topics, number of outlier documents).

        ``topics`` may be a plain Python list (BERTopic's fit_transform returns
        one) or a NumPy array. It is converted to an array before comparing:
        ``some_list == -1`` is a single False, which silently reports 0 outliers.
        """
        topic_ids = np.asarray(topics)
        n_outliers = int(np.sum(topic_ids == -1))
        n_topics = len(set(topic_ids.tolist())) - (1 if n_outliers else 0)
        return n_topics, n_outliers

    def initialize_model(self):
        """Build the UMAP -> HDBSCAN -> BERTopic pipeline from ``config``."""
        print(" Initializing BERTopic model...")

        umap_model = UMAP(**config.UMAP_PARAMS)
        hdbscan_model = HDBSCAN(**config.HDBSCAN_PARAMS)

        # Term extraction for topic representations.
        vectorizer_model = CountVectorizer(
            stop_words="english",
            ngram_range=config.BERTOPIC_SETTINGS['n_gram_range'],
            min_df=2,
            max_df=0.95,
            max_features=8000
        )

        self.topic_model = BERTopic(
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            vectorizer_model=vectorizer_model,
            top_n_words=config.BERTOPIC_SETTINGS['top_n_words'],
            min_topic_size=config.BERTOPIC_SETTINGS['min_topic_size'],
            calculate_probabilities=config.BERTOPIC_SETTINGS['calculate_probabilities'],
            verbose=config.BERTOPIC_SETTINGS['verbose'],
            nr_topics=config.BERTOPIC_SETTINGS['nr_topics']
        )

        print(" BERTopic model initialized")

    def fit_model(self, documents):
        """Embed ``documents`` and fit the topic model on them.

        Falls back to a default-configured BERTopic if fitting the tuned
        pipeline raises.

        Returns:
            (topics, probabilities) as produced by BERTopic.fit_transform.
        """
        print(" Fitting topic model...")

        model = SentenceTransformer(config.EMBEDDING_MODEL)
        self.embeddings = model.encode(
            documents,
            batch_size=config.BATCH_SIZE,
            show_progress_bar=True,
            normalize_embeddings=True
        )

        self.initialize_model()

        try:
            self.topics, self.probabilities = self.topic_model.fit_transform(
                documents, self.embeddings
            )
        except Exception as e:
            print(f" Primary fitting failed: {e}")
            print(" Using fallback model...")
            self.topic_model = BERTopic(
                min_topic_size=10,
                verbose=False,
                nr_topics=15
            )
            self.topics, self.probabilities = self.topic_model.fit_transform(
                documents, self.embeddings
            )

        self._optimize_topic_count(documents)
        self._print_results()

        return self.topics, self.probabilities

    def _optimize_topic_count(self, documents):
        """Report the topic count and warn when clustering found too few topics.

        BERTopic's reduce_topics() can only merge existing topics, so it cannot
        raise the count. The previous call also passed the old positional
        (docs, topics, probabilities) arguments, which recent BERTopic versions
        reject, so it always ended in its except branch. Topic assignments are
        therefore left untouched here.
        """
        unique_topics, _ = self._topic_stats(self.topics)
        print(f" Initial topics found: {unique_topics}")

        if unique_topics < 8:
            print(f" Only {unique_topics} topics for {len(documents)} documents; "
                  "topic reduction cannot add topics - consider lowering "
                  "HDBSCAN min_cluster_size / min_samples.")

    def _print_results(self):
        """Print topic/outlier counts, topic-size statistics and top words per topic."""
        unique_topics, outliers = self._topic_stats(self.topics)
        n_docs = len(self.topics)

        # Sizes of real topics only (-1 is the outlier bucket).
        sizes = [count for topic_id, count in Counter(self.topics).items() if topic_id != -1]
        if sizes:
            stats = {
                'min_size': min(sizes),
                'max_size': max(sizes),
                'avg_size': np.mean(sizes),
                'std_size': np.std(sizes)
            }
        else:
            stats = {'min_size': 0, 'max_size': 0, 'avg_size': 0, 'std_size': 0}

        outlier_pct = outliers / n_docs * 100 if n_docs else 0.0

        print("\n TOPIC MODELING RESULTS:")
        print(f"   • Topics found: {unique_topics}")
        print(f"   • Outliers: {outliers} ({outlier_pct:.1f}%)")
        print(f"   • Topic size range: {stats['min_size']} - {stats['max_size']}")
        print(f"   • Average topic size: {stats['avg_size']:.1f}")

        if hasattr(self.topic_model, 'get_topic_info'):
            try:
                topic_info = self.topic_model.get_topic_info()
                print("\n TOPIC DETAILS:")
                print("=" * 70)

                for _, row in topic_info.iterrows():
                    topic_id = row['Topic']
                    if topic_id != -1:
                        topic_words = self.topic_model.get_topic(topic_id)
                        if topic_words:
                            words = [word for word, _ in topic_words[:6]]
                            size_percentage = (row['Count'] / max(n_docs, 1)) * 100
                            print(f"   Topic {topic_id:2d} ({row['Count']:3d} docs, {size_percentage:4.1f}%): {', '.join(words)}")

            except Exception as e:
                print(f"   • Could not extract topic details: {e}")

# =============================================================================
# IMPROVED MODEL EVALUATOR
# =============================================================================

class ModelEvaluator:
    """Score a fitted topic model on coherence, outliers and topic balance.

    evaluate_model() stores its result dict in ``self.results`` and prints a
    summary.
    """

    def __init__(self):
        self.results = {}

    def evaluate_model(self, topic_model, documents, topics, embeddings):
        """Compute evaluation metrics for one fitted model.

        Args:
            topic_model: fitted model exposing ``get_topic(topic_id)`` that
                returns a list of (word, weight) pairs.
            documents: the texts the model was fitted on.
            topics: per-document topic ids (list or NumPy array); -1 marks
                outliers.
            embeddings: document embeddings; forwarded to
                _calculate_additional_metrics, which does not use them.

        Returns:
            dict with 'basic_stats', 'coherence_score', 'topic_quality',
            'additional_metrics' and 'overall_score'.
        """
        print(" Performing model evaluation...")

        evaluation = {}

        # Convert to an array first: `some_list == -1` is a single False,
        # which would report 0 outliers for list input (what BERTopic returns).
        topic_ids = np.asarray(topics)
        outliers = int(np.sum(topic_ids == -1))
        unique_topics = len(set(topic_ids.tolist())) - (1 if outliers else 0)
        n_assigned = len(topic_ids)

        evaluation['basic_stats'] = {
            'total_documents': int(len(documents)),
            'topics_found': int(unique_topics),
            'outliers': outliers,
            'outlier_percentage': float((outliers / n_assigned) * 100) if n_assigned else 0.0
        }

        coherence = self._calculate_robust_coherence(topic_model, documents, topics)
        evaluation['coherence_score'] = float(coherence)

        evaluation['topic_quality'] = self._analyze_topic_quality(topics)
        evaluation['additional_metrics'] = self._calculate_additional_metrics(topics, embeddings)
        evaluation['overall_score'] = float(self._calculate_improved_overall_score(evaluation))

        self.results = evaluation
        self._print_evaluation_results()

        return evaluation

    def _calculate_robust_coherence(self, topic_model, documents, topics):
        """Return the gensim c_v coherence of the topics' top-6 words.

        Returns 0.0 in the following cases:
        - fewer than two usable topics remain;
        - fewer than 15 documents have more than 10 whitespace tokens;
        - the computation fails (the error is printed).
        """
        try:
            topic_words = []
            for topic in set(topics):
                if topic != -1:
                    words = topic_model.get_topic(topic)
                    if words and len(words) >= 3:  # only topics with enough words
                        topic_words.append([word for word, _ in words[:6]])

            if len(topic_words) < 2:
                return 0.0

            # Only longer documents give meaningful co-occurrence statistics.
            tokenized_docs = [doc.split() for doc in documents if doc and len(doc.split()) > 10]

            if len(tokenized_docs) < 15:
                return 0.0

            dictionary = corpora.Dictionary(tokenized_docs)
            dictionary.filter_extremes(no_below=3, no_above=0.75)

            # gensim's CoherenceModel cannot interpret a topic none of whose
            # words are in the dictionary (typical for bigram topic terms or
            # words removed by filter_extremes). It raises, which used to zero
            # the whole score. Keep in-vocabulary words, and only topics that
            # still have at least two of them.
            vocab = dictionary.token2id
            topic_words = [
                kept for kept in ([w for w in words if w in vocab] for words in topic_words)
                if len(kept) >= 2
            ]
            if len(topic_words) < 2:
                return 0.0

            coherence_model = CoherenceModel(
                topics=topic_words,
                texts=tokenized_docs,
                dictionary=dictionary,
                coherence='c_v',
                processes=1
            )

            coherence_score = coherence_model.get_coherence()
            return max(0.0, float(coherence_score))  # ensure non-negative

        except Exception as e:
            print(f" Coherence calculation failed: {e}")
            return 0.0

    def _analyze_topic_quality(self, topics):
        """Return size-balance and diversity metrics over non-outlier topics."""
        topic_counts = Counter(topics)
        valid_topics = {k: v for k, v in topic_counts.items() if k != -1}

        if not valid_topics:
            return {
                'balance_score': 0.0,
                'size_variation': 0.0,
                'avg_topic_size': 0.0,
                'min_topic_size': 0.0,
                'max_topic_size': 0.0,
                'topic_diversity': 0.0
            }

        sizes = list(valid_topics.values())

        # Gini-Simpson diversity: 1 - sum of squared topic shares.
        total_docs = sum(sizes)
        if total_docs > 0:
            proportions = [size / total_docs for size in sizes]
            diversity = 1 - sum(p**2 for p in proportions)
        else:
            diversity = 0.0

        return {
            'balance_score': float(min(sizes) / max(sizes)) if max(sizes) > 0 else 0.0,
            'size_variation': float(np.std(sizes) / np.mean(sizes)) if np.mean(sizes) > 0 else 0.0,
            'avg_topic_size': float(np.mean(sizes)),
            'min_topic_size': float(min(sizes)),
            'max_topic_size': float(max(sizes)),
            'topic_diversity': float(diversity)
        }

    def _calculate_additional_metrics(self, topics, embeddings):
        """Score how reasonable the number of topics is.

        Bands: <=5 -> 0.2, <=10 -> 0.5, <=20 -> 0.8, <=30 -> 0.6, else 0.3.
        ``embeddings`` is accepted for interface compatibility but unused.
        """
        try:
            unique_topics = len(set(topics)) - (1 if -1 in topics else 0)
        except Exception:
            # Best effort: an unusable `topics` must not abort the evaluation.
            return {
                'topic_count_score': 0.0,
                'unique_topics': 0
            }

        if unique_topics <= 5:
            topic_count_score = 0.2
        elif unique_topics <= 10:
            topic_count_score = 0.5
        elif unique_topics <= 20:
            topic_count_score = 0.8
        elif unique_topics <= 30:
            topic_count_score = 0.6
        else:
            topic_count_score = 0.3

        return {
            'topic_count_score': float(topic_count_score),
            'unique_topics': int(unique_topics)
        }

    def _calculate_improved_overall_score(self, evaluation):
        """Combine the metrics into one weighted score clipped to [0, 1]."""
        basic = evaluation['basic_stats']
        coherence = evaluation['coherence_score']
        quality = evaluation['topic_quality']
        additional = evaluation['additional_metrics']

        score = (
            coherence * 0.35 +
            (1 - basic['outlier_percentage'] / 100) * 0.25 +  # fewer outliers is better
            quality['balance_score'] * 0.20 +
            quality['topic_diversity'] * 0.10 +
            additional['topic_count_score'] * 0.10
        )

        return min(max(score, 0.0), 1.0)

    def _print_evaluation_results(self):
        """Print the metrics stored in ``self.results``."""
        print("\n MODEL EVALUATION:")
        print("=" * 50)

        basic = self.results['basic_stats']
        print(f"   • Documents: {basic['total_documents']}")
        print(f"   • Topics: {basic['topics_found']}")
        print(f"   • Outliers: {basic['outliers']} ({basic['outlier_percentage']:.1f}%)")

        print(f"   • Coherence Score: {self.results['coherence_score']:.3f}")

        quality = self.results['topic_quality']
        print(f"   • Balance Score: {quality['balance_score']:.3f}")
        print(f"   • Topic Diversity: {quality['topic_diversity']:.3f}")
        print(f"   • Avg Topic Size: {quality['avg_topic_size']:.1f}")
        print(f"   • Size Range: {quality['min_topic_size']} - {quality['max_topic_size']}")

        additional = self.results['additional_metrics']
        print(f"   • Topic Count Score: {additional['topic_count_score']:.3f}")

        overall = self.results['overall_score']
        assessment = "EXCELLENT" if overall > 0.7 else "GOOD" if overall > 0.5 else "FAIR" if overall > 0.3 else "POOR"
        print(f"   • Overall Score: {overall:.3f} - {assessment}")



class ResultsVisualizer:
    """Render static plots, an interactive dashboard and a text report from saved results.

    All inputs are read back from the artifacts written into ``results_dir``
    (``document_assignments.csv``, ``domain_mapping.csv``, ``topic_info.csv``,
    ``evaluation.json``), so the visualizer works on an existing results folder
    without the fitted model.
    """

    # Confidence buckets for plot_confidence_analysis. pd.cut bins are
    # right-closed; include_lowest=True keeps a confidence of exactly 0.0.
    CONFIDENCE_BINS = [0, 0.5, 0.75, 1.0]
    CONFIDENCE_LABELS = ['Low (<0.5)', 'Medium (0.5-0.75)', 'High (>0.75)']
    CONFIDENCE_COLORS = ['lightcoral', 'gold', 'lightgreen']

    def __init__(self, results_dir="domain_modeling_results"):
        self.results_dir = results_dir
        self.setup_plot_style()

    def setup_plot_style(self):
        """Apply the default matplotlib style and a seaborn 'husl' palette."""
        plt.style.use('default')
        sns.set_palette("husl")
        # Stored for reference; not used by the methods in this class.
        self.colors = px.colors.qualitative.Set3

    def load_results(self):
        """Load every saved result file into instance attributes.

        Returns:
            True when all files were read. On any failure the error is printed,
            False is returned and previously loaded attributes are left as-is.
        """
        print(" Loading saved results for visualization...")

        try:
            doc_assignments = pd.read_csv(f"{self.results_dir}/document_assignments.csv")
            domain_mapping = pd.read_csv(f"{self.results_dir}/domain_mapping.csv")
            topic_info = pd.read_csv(f"{self.results_dir}/topic_info.csv")

            with open(f"{self.results_dir}/evaluation.json", 'r', encoding='utf-8') as f:
                evaluation = json.load(f)

        except Exception as e:  # reporting boundary: caller decides what to do
            print(f" Error loading results: {e}")
            return False

        # Assign only after everything was read, so a failed load never leaves
        # a mix of fresh and stale tables behind.
        self.doc_assignments = doc_assignments
        self.domain_mapping = domain_mapping
        self.topic_info = topic_info
        self.evaluation = evaluation

        print(" All results loaded successfully")
        return True

    def _ensure_loaded(self):
        """Return True once all result tables are in memory, loading them on demand."""
        required = ('doc_assignments', 'domain_mapping', 'topic_info', 'evaluation')
        if all(hasattr(self, name) for name in required):
            return True
        return self.load_results()

    @classmethod
    def _confidence_level_counts(cls, confidence):
        """Count values per confidence bucket, in bucket order (Low, Medium, High)."""
        levels = pd.cut(confidence, bins=cls.CONFIDENCE_BINS,
                        labels=cls.CONFIDENCE_LABELS, include_lowest=True)
        # sort=False keeps category order so colors can be matched by label.
        return levels.value_counts(sort=False)

    @staticmethod
    def _split_keywords(value):
        """Split a comma-separated keyword cell; NaN/None/non-strings yield []."""
        if not isinstance(value, str):
            return []
        return [kw.strip() for kw in value.split(',') if kw.strip()]

    @staticmethod
    def _domain_file_slug(domain):
        """Turn a domain label into a file-name fragment.

        The label is lower-cased and spaces/path separators become underscores,
        so labels such as 'Bio/Chem' cannot create unintended sub-directories.
        """
        return str(domain).lower().replace(' ', '_').replace('/', '_').replace('\\', '_')

    def create_comprehensive_visualizations(self):
        """Reload results and write every plot into ``<results_dir>/visualizations``."""
        if not self.load_results():
            return

        print("\n CREATING COMPREHENSIVE VISUALIZATIONS...")

        # Create output directory
        viz_dir = f"{self.results_dir}/visualizations"
        os.makedirs(viz_dir, exist_ok=True)

        # Create all visualizations
        self.plot_domain_distribution(viz_dir)
        self.plot_topic_size_distribution(viz_dir)
        self.plot_confidence_analysis(viz_dir)
        self.create_topic_wordclouds(viz_dir)
        self.create_interactive_dashboard(viz_dir)
        self.create_domain_analysis(viz_dir)

        print(f" All visualizations saved to: {viz_dir}")

    def plot_domain_distribution(self, output_dir):
        """Save a pie chart and a bar chart of topics per primary domain."""
        print("    Creating domain distribution plots...")

        domain_counts = self.domain_mapping['primary_domain'].value_counts()

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # Pie chart
        colors = plt.cm.Set3(np.linspace(0, 1, len(domain_counts)))
        ax1.pie(domain_counts.values, labels=domain_counts.index,
                autopct='%1.1f%%', colors=colors, startangle=90)
        ax1.set_title('Domain Distribution - Topic Count', fontsize=14, fontweight='bold')

        # Bar chart
        y_pos = np.arange(len(domain_counts))
        bars = ax2.barh(y_pos, domain_counts.values, color=colors)
        ax2.set_yticks(y_pos)
        ax2.set_yticklabels(domain_counts.index)
        ax2.set_xlabel('Number of Topics')
        ax2.set_title('Domain Distribution - Topic Count', fontsize=14, fontweight='bold')
        ax2.bar_label(bars, padding=3)

        plt.tight_layout()
        plt.savefig(f"{output_dir}/domain_distribution.png", dpi=300, bbox_inches='tight')
        plt.close()

    def plot_topic_size_distribution(self, output_dir):
        """Save histogram, box plot, top-10 bar chart and per-domain average of topic sizes."""
        print("    Creating topic size distribution plots...")

        # Topic -1 is excluded here as the outlier bucket.
        real_topics = self.topic_info[self.topic_info['Topic'] != -1]
        topic_sizes = real_topics['Count']

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Histogram
        ax1.hist(topic_sizes, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
        ax1.set_xlabel('Documents per Topic')
        ax1.set_ylabel('Frequency')
        ax1.set_title('Topic Size Distribution', fontweight='bold')
        ax1.grid(True, alpha=0.3)

        # Box plot
        ax2.boxplot(topic_sizes, vert=False)
        ax2.set_xlabel('Number of Documents')
        ax2.set_title('Topic Size Distribution (Box Plot)', fontweight='bold')
        ax2.grid(True, alpha=0.3)

        # Top topics bar chart
        top_topics = real_topics.nlargest(10, 'Count')
        y_pos = np.arange(len(top_topics))
        colors = plt.cm.viridis(np.linspace(0, 1, len(top_topics)))

        bars = ax3.barh(y_pos, top_topics['Count'], color=colors)
        ax3.set_yticks(y_pos)
        ax3.set_yticklabels([f"Topic {t}" for t in top_topics['Topic']])
        ax3.set_xlabel('Number of Documents')
        ax3.set_title('Top 10 Largest Topics', fontweight='bold')
        ax3.bar_label(bars, padding=3)
        ax3.invert_yaxis()

        # Size vs Domain
        domain_sizes = self.domain_mapping.merge(
            self.topic_info[['Topic', 'Count']],
            left_on='topic_id', right_on='Topic'
        )
        domain_avg_sizes = domain_sizes.groupby('primary_domain')['Count'].mean().sort_values()

        y_pos = np.arange(len(domain_avg_sizes))
        bars = ax4.barh(y_pos, domain_avg_sizes.values,
                        color=plt.cm.Set3(np.linspace(0, 1, len(domain_avg_sizes))))
        ax4.set_yticks(y_pos)
        ax4.set_yticklabels(domain_avg_sizes.index)
        ax4.set_xlabel('Average Documents per Topic')
        ax4.set_title('Average Topic Size by Domain', fontweight='bold')
        ax4.bar_label(bars, fmt='%.1f', padding=3)

        plt.tight_layout()
        plt.savefig(f"{output_dir}/topic_size_distribution.png", dpi=300, bbox_inches='tight')
        plt.close()

    def plot_confidence_analysis(self, output_dir):
        """Save four views of the domain-classification confidence scores."""
        print("    Creating confidence analysis plots...")

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Confidence distribution
        ax1.hist(self.domain_mapping['confidence'], bins=15, alpha=0.7, color='lightcoral', edgecolor='black')
        ax1.set_xlabel('Confidence Score')
        ax1.set_ylabel('Frequency')
        ax1.set_title('Distribution of Classification Confidence', fontweight='bold')
        ax1.grid(True, alpha=0.3)

        # Confidence by domain
        mean_conf = self.domain_mapping.groupby('primary_domain')['confidence'].mean().sort_values()
        y_pos = np.arange(len(mean_conf))
        colors = plt.cm.RdYlGn(np.linspace(0.2, 0.8, len(mean_conf)))

        bars = ax2.barh(y_pos, mean_conf.values, color=colors)
        ax2.set_yticks(y_pos)
        ax2.set_yticklabels(mean_conf.index)
        ax2.set_xlabel('Average Confidence Score')
        ax2.set_title('Average Confidence by Domain', fontweight='bold')
        ax2.set_xlim(0, 1)
        ax2.bar_label(bars, fmt='%.3f', padding=3)

        # Confidence categories: colors are looked up by label, because a
        # frequency-sorted value_counts() would pair colors with the wrong bucket.
        conf_counts = self._confidence_level_counts(self.domain_mapping['confidence'])
        conf_counts = conf_counts[conf_counts > 0]
        color_for = dict(zip(self.CONFIDENCE_LABELS, self.CONFIDENCE_COLORS))
        ax3.pie(conf_counts.values, labels=conf_counts.index, autopct='%1.1f%%',
                colors=[color_for[label] for label in conf_counts.index])
        ax3.set_title('Confidence Level Distribution', fontweight='bold')

        # Domain-wise confidence distribution
        domains = self.domain_mapping['primary_domain'].unique()
        per_domain = [
            self.domain_mapping.loc[self.domain_mapping['primary_domain'] == domain, 'confidence'].values
            for domain in domains
        ]
        # Boxes sit at positions 1..n; label them explicitly instead of using
        # boxplot(labels=...), which newer matplotlib renamed to tick_labels.
        ax4.boxplot(per_domain)
        ax4.set_xticks(range(1, len(domains) + 1))
        ax4.set_xticklabels(domains, rotation=45)
        ax4.set_ylabel('Confidence Score')
        ax4.set_title('Confidence Distribution by Domain', fontweight='bold')
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f"{output_dir}/confidence_analysis.png", dpi=300, bbox_inches='tight')
        plt.close()

    def create_topic_wordclouds(self, output_dir):
        """Save word clouds of topic keywords pooled per domain (first 6 domains)."""
        print("    Creating topic word clouds...")

        # Pool keywords of all topics in each domain, preserving first-seen domain order.
        domain_words = {}
        for domain, keywords in zip(self.domain_mapping['primary_domain'],
                                    self.domain_mapping['topic_keywords']):
            domain_words.setdefault(domain, []).extend(self._split_keywords(keywords))

        # WordCloud cannot draw an empty frequency table, so skip keyword-less domains.
        domain_words = {domain: words for domain, words in domain_words.items() if words}

        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        axes = axes.flatten()

        for ax, (domain, words) in zip(axes, domain_words.items()):
            wordcloud = WordCloud(width=400, height=300,
                                  background_color='white',
                                  colormap='viridis').generate_from_frequencies(Counter(words))

            ax.imshow(wordcloud, interpolation='bilinear')
            ax.set_title(f'{domain}\n({len(words)} keywords)', fontweight='bold')
            ax.axis('off')

        # Hide unused subplots
        for ax in axes[len(domain_words):]:
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(f"{output_dir}/domain_wordclouds.png", dpi=300, bbox_inches='tight')
        plt.close()

    def create_interactive_dashboard(self, output_dir):
        """Write an interactive Plotly scatter of topic size vs. confidence by domain."""
        print("    Creating interactive dashboard...")

        # Prepare data
        topic_data = self.topic_info[self.topic_info['Topic'] != -1].merge(
            self.domain_mapping, left_on='Topic', right_on='topic_id'
        )

        fig = px.scatter(topic_data,
                         x='Count',
                         y='confidence',
                         size='Count',
                         color='primary_domain',
                         hover_data=['Topic', 'topic_keywords'],
                         title='Topic Analysis: Size vs Confidence by Domain',
                         labels={'Count': 'Number of Documents', 'confidence': 'Classification Confidence'},
                         size_max=60)

        fig.update_layout(showlegend=True)
        fig.write_html(f"{output_dir}/interactive_topic_analysis.html")

    def create_domain_analysis(self, output_dir):
        """Save one figure per domain with its largest topics and confidence histogram."""
        print("    Creating domain-wise analysis...")

        domain_data = self.domain_mapping.merge(
            self.topic_info[['Topic', 'Count', 'Name']],
            left_on='topic_id', right_on='Topic'
        )

        for domain in domain_data['primary_domain'].unique():
            domain_topics = domain_data[domain_data['primary_domain'] == domain]

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

            # Topic sizes for this domain
            topics_sorted = domain_topics.nlargest(10, 'Count')
            y_pos = np.arange(len(topics_sorted))

            bars = ax1.barh(y_pos, topics_sorted['Count'],
                            color=plt.cm.Blues(np.linspace(0.4, 0.8, len(topics_sorted))))
            ax1.set_yticks(y_pos)
            ax1.set_yticklabels([f"Topic {t}" for t in topics_sorted['Topic']])
            ax1.set_xlabel('Number of Documents')
            ax1.set_title(f'{domain} - Topic Sizes', fontweight='bold')
            ax1.bar_label(bars, padding=3)
            ax1.invert_yaxis()

            # Confidence distribution for this domain
            ax2.hist(domain_topics['confidence'], bins=10, alpha=0.7,
                     color='lightgreen', edgecolor='black')
            ax2.set_xlabel('Confidence Score')
            ax2.set_ylabel('Number of Topics')
            ax2.set_title(f'{domain} - Confidence Distribution', fontweight='bold')
            ax2.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(f"{output_dir}/domain_analysis_{self._domain_file_slug(domain)}.png",
                        dpi=300, bbox_inches='tight')
            plt.close()

    def generate_comprehensive_report(self):
        """Write ``comprehensive_report.txt`` with headline metrics and domain counts.

        Results are loaded on demand if they are not already in memory. If they
        cannot be loaded, the error is printed and no report is written.
        """
        print("\n GENERATING COMPREHENSIVE REPORT...")

        if not self._ensure_loaded():
            return

        report_lines = []

        # Header
        report_lines.append("="*80)
        report_lines.append("SCIENTIFIC DOMAIN MODELING REPORT")
        report_lines.append("="*80)
        report_lines.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report_lines.append(f"Total Documents: {len(self.doc_assignments)}")
        report_lines.append(f"Total Topics: {len(self.domain_mapping)}")
        report_lines.append("")

        # Executive Summary
        report_lines.append("EXECUTIVE SUMMARY")
        report_lines.append("-"*40)
        report_lines.append(f"Overall Quality Score: {self.evaluation['overall_score']:.3f}")
        report_lines.append(f"Topic Coherence: {self.evaluation['coherence_score']:.3f}")
        report_lines.append(f"Outlier Percentage: {self.evaluation['basic_stats']['outlier_percentage']:.1f}%")
        report_lines.append("")

        # Domain Distribution (loop body only runs when domain_mapping is non-empty)
        report_lines.append("DOMAIN DISTRIBUTION")
        report_lines.append("-"*40)
        domain_counts = self.domain_mapping['primary_domain'].value_counts()
        for domain, count in domain_counts.items():
            percentage = (count / len(self.domain_mapping)) * 100
            report_lines.append(f"{domain:25} {count:2d} topics ({percentage:5.1f}%)")
        report_lines.append("")

        # Save report
        report_path = f"{self.results_dir}/comprehensive_report.txt"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        print(f" Comprehensive report saved to: {report_path}")

# =============================================================================
# MAIN PIPELINE
# =============================================================================

class DomainModelingPipeline:
    """Orchestrate one end-to-end domain-modeling run.

    Steps: clean data -> fit topic model -> classify topics into domains ->
    evaluate -> persist artifacts under ``config.OUTPUT_DIR`` -> render
    visualizations and a text report from those saved artifacts.
    """

    def __init__(self):
        self.data_processor = DataProcessor()
        self.topic_modeler = TopicModeler()
        self.domain_classifier = EnhancedDomainClassifier()
        self.evaluator = ModelEvaluator()
        # Read results back from the same directory _save_results writes to,
        # so changing config.OUTPUT_DIR cannot desynchronise the two.
        self.visualizer = ResultsVisualizer(config.OUTPUT_DIR)

    def run_pipeline(self):
        """Run all pipeline steps.

        Returns:
            dict with 'topics', 'domain_mapping', 'evaluation', 'model' and
            'documents' on success; None if any step raised (the traceback is
            printed).
        """
        print(" STARTING DOMAIN MODELING PIPELINE")
        print("=" * 60)

        try:
            # Step 1: Data processing
            print("\n STEP 1: DATA PROCESSING")
            df = self.data_processor.load_and_process_data()
            documents = df['cleaned_text'].tolist()
            print(f" Processing {len(documents)} documents")

            # Step 2: Topic modeling (per-document probabilities are not used further)
            print("\n STEP 2: TOPIC MODELING")
            topics, _probabilities = self.topic_modeler.fit_model(documents)

            # Step 3: Domain classification
            print("\n STEP 3: DOMAIN CLASSIFICATION")
            domain_mapping = self.domain_classifier.classify_topic_domains(
                self.topic_modeler.topic_model, topics
            )

            # Step 4: Evaluation
            print("\n STEP 4: MODEL EVALUATION")
            evaluation = self.evaluator.evaluate_model(
                self.topic_modeler.topic_model, documents, topics, self.topic_modeler.embeddings
            )

            # Step 5: Save results (raises on failure so stale files are never visualised)
            print("\n STEP 5: SAVING RESULTS")
            self._save_results(df, topics, domain_mapping, evaluation)

            # Step 6: Create visualizations and reports
            print("\n STEP 6: CREATING VISUALIZATIONS & REPORTS")
            self.visualizer.create_comprehensive_visualizations()
            self.visualizer.generate_comprehensive_report()

            print("\n PIPELINE COMPLETED SUCCESSFULLY!")

            # Final summary
            self._print_final_summary(evaluation, domain_mapping)

            return {
                'topics': topics,
                'domain_mapping': domain_mapping,
                'evaluation': evaluation,
                'model': self.topic_modeler.topic_model,
                'documents': documents
            }

        except Exception as e:  # top-level boundary: report and signal failure
            print(f" Pipeline execution failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _print_final_summary(self, evaluation, domain_mapping):
        """Print topic/domain counts, the overall score and the per-domain split.

        Percentages are relative to the number of topics present in
        ``domain_mapping``, so they always sum to 100%. This avoids a division by
        zero when ``topics_found`` is 0 and a mismatch when it counts topics
        differently from the mapping.
        """
        print(f"\n FINAL SUMMARY:")
        print("=" * 50)

        topics_found = evaluation['basic_stats']['topics_found']
        overall_score = evaluation['overall_score']

        domain_counts = Counter(mapping['primary_domain'] for mapping in domain_mapping.values())
        classified_topics = sum(domain_counts.values())

        print(f"   • Topics Discovered: {topics_found}")
        print(f"   • Domains Identified: {len(domain_counts)}")
        print(f"   • Overall Quality: {overall_score:.3f}")

        # Domain breakdown (loop body only runs when classified_topics > 0)
        print(f"   • Domain Distribution:")
        for domain, count in domain_counts.most_common():
            percentage = (count / classified_topics) * 100
            print(f"      - {domain}: {count} topics ({percentage:.1f}%)")

    def _save_results(self, df, topics, domain_mapping, evaluation):
        """Persist assignments, model, domain mapping, evaluation and topic info.

        Everything is written under ``config.OUTPUT_DIR``.

        Raises:
            Any exception raised while saving. The error is printed first and
            then re-raised, because the visualization step reads these files
            back; continuing would silently render an earlier run's results.
        """
        try:
            output_dir = config.OUTPUT_DIR
            os.makedirs(output_dir, exist_ok=True)

            # Create document assignments dataframe
            results_df = df.copy()
            results_df['topic'] = topics
            results_df['domain'] = results_df['topic'].map(
                {k: v['primary_domain'] for k, v in domain_mapping.items()}
            )
            results_df['confidence'] = results_df['topic'].map(
                {k: v['confidence'] for k, v in domain_mapping.items()}
            )

            # Save document assignments
            results_df.to_csv(f"{output_dir}/document_assignments.csv", index=False)
            print(f" Document assignments saved")

            # Save topic model
            self.topic_modeler.topic_model.save(f"{output_dir}/topic_model")
            print(f" Topic model saved")

            # Save domain mapping
            domain_df = pd.DataFrame([
                {
                    'topic_id': topic_id,
                    'primary_domain': info['primary_domain'],
                    'confidence': info['confidence'],
                    'topic_keywords': ', '.join(info['topic_keywords'])
                }
                for topic_id, info in domain_mapping.items()
            ])
            domain_df.to_csv(f"{output_dir}/domain_mapping.csv", index=False)
            print(f" Domain mapping saved")

            # Save evaluation; numpy scalars/arrays are converted by _json_serializer
            with open(f"{output_dir}/evaluation.json", 'w', encoding='utf-8') as f:
                json.dump(evaluation, f, indent=2, default=self._json_serializer)
            print(f" Evaluation results saved")

            # Save topic information
            if hasattr(self.topic_modeler.topic_model, 'get_topic_info'):
                topic_info = self.topic_modeler.topic_model.get_topic_info()
                topic_info.to_csv(f"{output_dir}/topic_info.csv", index=False)
                print(f" Topic information saved")

            print(f" All results saved to: {output_dir}")

        except Exception as e:
            print(f" Error saving results: {e}")
            raise

    def _json_serializer(self, obj):
        """``json.dump`` fallback converting numpy scalars and arrays to Python types.

        Raises:
            TypeError: for any other unsupported object, as json expects.
        """
        if isinstance(obj, (np.integer, np.int32, np.int64)):
            return int(obj)
        elif isinstance(obj, (np.floating, np.float32, np.float64)):
            return float(obj)
        elif isinstance(obj, (np.ndarray,)):
            return obj.tolist()
        elif isinstance(obj, (np.bool_)):
            return bool(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

# =============================================================================
# EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Keep third-party libraries quiet: only WARNING and above get through.
    logging.getLogger("bertopic").setLevel(logging.WARNING)
    logging.getLogger("umap").setLevel(logging.WARNING)
    logging.getLogger("hdbscan").setLevel(logging.WARNING)

    # Run banner: one print of joined lines produces the same output as six prints.
    print("\n".join([
        " OPTIMIZED DOMAIN MODELING PIPELINE",
        "=" * 60,
        f"   • Embedding Model: {config.EMBEDDING_MODEL}",
        f"   • Output Directory: {config.OUTPUT_DIR}",
        "   • Features: Optimized Topic Modeling + Fixed Domain Classification + Visualization",
        "=" * 60,
    ]))

    # run_pipeline() returns None on failure, a results dict otherwise.
    pipeline = DomainModelingPipeline()
    results = pipeline.run_pipeline()

    if results is None:
        print("\n Pipeline execution failed")
    else:
        print("\n PIPELINE COMPLETED SUCCESSFULLY!")
        print(f" Check '{config.OUTPUT_DIR}' folder for all results and visualizations")

In [None]:
import logging
import pandas as pd
import numpy as np
import os
from datetime import datetime, timezone
import torch
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import silhouette_score
import pickle
import re
import json
from collections import Counter
import gensim
from gensim import corpora
from gensim.models.coherencemodel import CoherenceModel
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from wordcloud import WordCloud
from gensim.models import LdaModel
from sklearn.decomposition import LatentDirichletAllocation
import time
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm

# Suppress library warnings so notebook output only shows pipeline progress.
warnings.filterwarnings('ignore')

# Environment sanity check: confirms the imports above succeeded and shows versions.
print("All imports completed successfully!")
print("PyTorch version:", torch.__version__)
print("Pandas version:", pd.__version__)


class EnhancedDomainConfig:
    """Settings for the enhanced pipeline, exposed as class-level constants.

    Creating an instance has one side effect: ``OUTPUT_DIR`` and
    ``COMPARISON_DIR`` are created (if missing) relative to the working directory.
    """

    # --- File locations ---
    PROCESSED_TEXT_CSV = "updated.csv"
    OUTPUT_DIR = "domain_modeling_results"
    COMPARISON_DIR = "model_comparison"

    # --- Sentence-embedding checkpoints (primary and a backup) ---
    EMBEDDING_MODELS = dict(
        primary="sentence-transformers/all-mpnet-base-v2",
        backup="sentence-transformers/all-MiniLM-L12-v2",
    )

    # --- UMAP dimensionality reduction (keyword arguments for UMAP) ---
    UMAP_PARAMS = dict(
        n_neighbors=10,
        n_components=5,
        min_dist=0.05,
        metric='cosine',
        random_state=42,
        low_memory=False,
    )

    # --- HDBSCAN clustering (keyword arguments for HDBSCAN) ---
    HDBSCAN_PARAMS = dict(
        min_cluster_size=10,
        min_samples=3,
        cluster_selection_epsilon=0.03,
        metric='euclidean',
        cluster_selection_method='eom',
        prediction_data=True,
    )

    # --- BERTopic options ---
    BERTOPIC_SETTINGS = dict(
        top_n_words=12,
        n_gram_range=(1, 2),
        min_topic_size=10,
        calculate_probabilities=True,
        verbose=False,
        nr_topics=20,
    )

    # --- scikit-learn LatentDirichletAllocation keyword arguments ---
    LDA_PARAMS = dict(
        n_components=20,
        random_state=42,
        learning_method='online',
        max_iter=25,
        batch_size=128,
        evaluate_every=5,
    )

    # --- Text processing (lengths are in characters) ---
    MIN_DOC_LENGTH = 50
    MAX_DOC_LENGTH = 2500
    BATCH_SIZE = 32

    # --- Evaluation: documents sampled per topic ---
    TOPIC_EVALUATION_SAMPLE_SIZE = 3

    def __init__(self):
        for directory in (self.OUTPUT_DIR, self.COMPARISON_DIR):
            os.makedirs(directory, exist_ok=True)

# Module-wide settings object read by the classes below; constructing it also
# creates the output and comparison directories.
config = EnhancedDomainConfig()
print("Enhanced Configuration initialized")


class EnhancedDataProcessor:
    def __init__(self):
        self.scientific_stop_words = set([
            'paper', 'study', 'research', 'result', 'method', 'approach',
            'show', 'demonstrate', 'present', 'investigate', 'analyze',
            'discuss', 'conclude', 'suggest', 'indicate', 'figure', 'table',
            'author', 'journal', 'publication', 'reference', 'citation',
            'section', 'abstract', 'introduction', 'background', 'conclusion',
            'however', 'therefore', 'moreover', 'furthermore', 'additionally'
        ])

        self.technical_terms = {
            'biology': ['gene', 'protein', 'cell', 'dna', 'rna', 'genome', 'organism'],
            'medicine': ['patient', 'clinical', 'treatment', 'diagnosis', 'therapy', 'drug'],
            'chemistry': ['molecule', 'compound', 'reaction', 'chemical', 'synthesis'],
            'physics': ['quantum', 'particle', 'energy', 'field', 'mechanics'],
            'computer_science': ['algorithm', 'neural', 'network', 'data', 'system'],
            'engineering': ['design', 'system', 'manufacturing', 'structural'],
            'materials_science': ['material', 'composite', 'polymer', 'ceramic'],
            'psychology': ['behavior', 'cognitive', 'mental', 'brain', 'memory']
        }

    def enhanced_clean_text(self, text):
        """Enhanced text cleaning with technical term preservation"""
        if not isinstance(text, str):
            return ""

        text = text.lower().strip()

        if len(text) < config.MIN_DOC_LENGTH:
            return ""
        if len(text) > config.MAX_DOC_LENGTH:
            text = text[:config.MAX_DOC_LENGTH]

        # Remove technical patterns
        patterns = [
            (r'http\S+|www\S+|https\S+', ' '),
            (r'doi:\s*\S+', ' '),
            (r'arXiv:\s*\d+\.\d+', ' '),
            (r'©\s*\d{4}', ' '),
            (r'fig\.?\s*\d+', ' '),
            (r'table\s*\d+', ' '),
            (r'equation\s*\d+', ' '),
            (r'\s+', ' ')
        ]

        for pattern, replacement in patterns:
            text = re.sub(pattern, replacement, text)

        # Tokenize with preservation of technical terms
        words = text.split()
        cleaned_words = []

        for word in words:
            if self._should_preserve_word(word):
                cleaned_words.append(word)

        cleaned_text = ' '.join(cleaned_words).strip()

        return cleaned_text if len(cleaned_text) >= config.MIN_DOC_LENGTH else ""

    def _should_preserve_word(self, word):
        """Determine if word should be preserved"""
        # Check if it's a technical term
        for domain, terms in self.technical_terms.items():
            if word in terms:
                return True

        # Basic filtering
        if (len(word) <= 2 or len(word) >= 30 or
            word in self.scientific_stop_words or
            word.isdigit()):
            return False

        # Keep reasonable words
        return len(word) > 3 and any(c.isalpha() for c in word)

    def load_and_process_data(self):
        """Load and process data with enhanced cleaning"""
        try:
            print(" Loading and processing data with enhanced cleaning...")

            df = pd.read_csv(config.PROCESSED_TEXT_CSV)
            print(f" Loaded {len(df)} documents")

            df['cleaned_text'] = df['processed_text'].apply(self.enhanced_clean_text)

            initial_count = len(df)
            df = df[df['cleaned_text'].str.len() > config.MIN_DOC_LENGTH].copy()
            final_count = len(df)

            if initial_count != final_count:
                removed = initial_count - final_count
                print(f" Removed {removed} documents after cleaning ({removed/initial_count*100:.1f}%)")

            # Text statistics
            word_counts = df['cleaned_text'].str.split().str.len()
            print(f" Final dataset: {final_count} documents")
            print(f" Average document length: {word_counts.mean():.1f} words")
            print(f" Total words: {word_counts.sum():,}")

            return df

        except Exception as e:
            print(f" Data loading failed: {e}")
            raise


class FixedLDAModel:
    """scikit-learn LDA topic model plus a gensim LDA twin used for c_v coherence.

    ``fit`` trains both models. Topic words, diversity and distinctiveness come
    from the scikit-learn model; the coherence score comes from the separately
    trained gensim model.
    """
    def __init__(self):
        self.lda_model = None          # sklearn LatentDirichletAllocation
        self.vectorizer = None         # CountVectorizer fitted in fit()
        self.dictionary = None         # gensim Dictionary over whitespace tokens
        self.corpus = None             # bag-of-words corpus for gensim
        self.feature_names = None      # vocabulary aligned with lda_model.components_
        self.perplexity_score = None
        self.coherence_score = None
        self.gensim_lda = None         # gensim LdaModel, used only for coherence

    def _n_topics(self):
        """Number of topics: from the fitted model, else the configured count."""
        if self.lda_model is not None:
            return self.lda_model.components_.shape[0]
        return config.LDA_PARAMS['n_components']

    def fit(self, documents):
        """Train the sklearn LDA (with perplexity) and the gensim LDA twin on ``documents``."""
        print(" Training Fixed LDA model...")
        start_time = time.time()

        # Unigram + bigram counts, English stop words removed
        self.vectorizer = CountVectorizer(
            max_df=0.9,
            min_df=2,
            stop_words='english',
            ngram_range=(1, 2),
            max_features=15000,
            lowercase=True,
            strip_accents='unicode'
        )

        X = self.vectorizer.fit_transform(documents)
        self.feature_names = self.vectorizer.get_feature_names_out()
        print(f"    Vocabulary size: {X.shape[1]:,} features")

        # Train sklearn LDA with config.LDA_PARAMS
        self.lda_model = LatentDirichletAllocation(**config.LDA_PARAMS)
        self.lda_model.fit(X)

        # Perplexity on the training matrix
        self.perplexity_score = self.lda_model.perplexity(X)

        # Gensim inputs for coherence (plain whitespace tokens, no stop-word removal)
        tokenized_docs = [doc.split() for doc in documents]
        self.dictionary = corpora.Dictionary(tokenized_docs)
        self.dictionary.filter_extremes(no_below=3, no_above=0.85)
        self.corpus = [self.dictionary.doc2bow(doc) for doc in tokenized_docs]

        # Train gensim LDA for coherence
        self.gensim_lda = LdaModel(
            corpus=self.corpus,
            id2word=self.dictionary,
            num_topics=config.LDA_PARAMS['n_components'],
            random_state=42,
            update_every=1,
            chunksize=100,
            passes=10,
            alpha='auto',
            per_word_topics=True
        )

        end_time = time.time()
        print(f" LDA training completed in {end_time - start_time:.2f}s")

    def get_topic_words(self, topic_id, n_words=10):
        """Return the ``n_words`` highest-weighted terms of ``topic_id``, best first.

        Returns [] if the model has not been fitted.
        """
        if self.lda_model is None:
            return []

        topic_weights = self.lda_model.components_[topic_id]
        # Indices of the n_words largest weights, in descending order
        top_indices = topic_weights.argsort()[:-n_words-1:-1]
        return [self.feature_names[i] for i in top_indices]

    def get_topics_dict(self, n_words=10):
        """Map every topic id of the fitted model to its top ``n_words`` terms."""
        return {i: self.get_topic_words(i, n_words) for i in range(self._n_topics())}

    def predict_topics(self, documents):
        """Return the most probable topic id for each document.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.vectorizer is None or self.lda_model is None:
            raise RuntimeError("FixedLDAModel is not fitted; call fit() first")
        X = self.vectorizer.transform(documents)
        topic_distributions = self.lda_model.transform(X)
        return topic_distributions.argmax(axis=1)

    def evaluate(self, documents):
        """Compute coherence, perplexity, diversity and distinctiveness.

        Returns:
            dict with float metrics plus 'num_topics' and 'model_type'.
        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.lda_model is None or self.gensim_lda is None:
            raise RuntimeError("FixedLDAModel is not fitted; call fit() first")

        print(" Evaluating Fixed LDA model...")

        # NOTE(review): c_v coherence is measured on the gensim twin, whose topics
        # are not the sklearn topics returned by get_topics_dict.
        tokenized_docs = [doc.split() for doc in documents]
        coherence_model = CoherenceModel(
            model=self.gensim_lda,
            texts=tokenized_docs,
            dictionary=self.dictionary,
            coherence='c_v'
        )
        coherence = coherence_model.get_coherence()
        self.coherence_score = coherence

        # Diversity: share of unique words among each topic's top-5 terms
        topics = self.get_topics_dict(n_words=10)
        all_words = [word for topic_words in topics.values() for word in topic_words[:5]]
        total_words = len(all_words)
        diversity = len(set(all_words)) / total_words if total_words > 0 else 0

        distinctiveness = self._calculate_topic_distinctiveness(self.lda_model.components_)

        evaluation = {
            'coherence_score': float(coherence),
            'perplexity_score': float(self.perplexity_score),
            'topic_diversity': float(diversity),
            'topic_distinctiveness': float(distinctiveness),
            'num_topics': int(self._n_topics()),
            'model_type': 'LDA_FIXED'
        }

        print(f"   • Coherence: {coherence:.3f}")
        print(f"   • Perplexity: {self.perplexity_score:.2f}")
        print(f"   • Topic Diversity: {diversity:.3f}")
        print(f"   • Topic Distinctiveness: {distinctiveness:.3f}")

        return evaluation

    def _calculate_topic_distinctiveness(self, topic_matrices):
        """Return 1 - mean pairwise cosine similarity between topic rows.

        Higher values mean more distinct topics. Returns 0.0 when there are
        fewer than two topics. Computed with one normalised Gram matrix instead
        of a Python double loop.
        """
        topic_matrices = np.asarray(topic_matrices, dtype=float)
        n_topics = topic_matrices.shape[0]
        if n_topics < 2:
            return 0.0

        unit_rows = topic_matrices / np.linalg.norm(topic_matrices, axis=1, keepdims=True)
        similarities = unit_rows @ unit_rows.T
        upper = np.triu_indices(n_topics, k=1)   # each unordered pair once
        return 1 - float(similarities[upper].mean())


class EnhancedTopicModeler:
    """Enhanced BERTopic modeler.

    Builds a BERTopic pipeline from the module-level ``config`` (UMAP, HDBSCAN
    and CountVectorizer settings), fits it on SentenceTransformer embeddings,
    optionally reduces the topic count, and prints summary statistics. Falls
    back to a default-configured BERTopic when construction or fitting fails.
    """

    def __init__(self):
        self.topic_model = None     # BERTopic instance once initialised/fitted
        self.embeddings = None      # document embeddings from SentenceTransformer
        self.topics = None          # per-document topic ids; -1 marks outliers
        self.probabilities = None   # probabilities returned by BERTopic

    @staticmethod
    def _count_topics(topics):
        """Return the number of distinct topic ids, excluding the -1 outlier label."""
        return len(set(topics)) - (1 if -1 in topics else 0)

    def initialize_enhanced_model(self):
        """Initialize enhanced BERTopic model from ``config``; fall back on error."""
        print(" Initializing enhanced BERTopic model...")

        try:
            umap_model = UMAP(**config.UMAP_PARAMS)
            hdbscan_model = HDBSCAN(**config.HDBSCAN_PARAMS)

            vectorizer_model = CountVectorizer(
                stop_words="english",
                ngram_range=config.BERTOPIC_SETTINGS['n_gram_range'],
                min_df=2,
                max_df=0.9,
                max_features=15000,
                lowercase=True,
                strip_accents='unicode'
            )

            self.topic_model = BERTopic(
                umap_model=umap_model,
                hdbscan_model=hdbscan_model,
                vectorizer_model=vectorizer_model,
                top_n_words=config.BERTOPIC_SETTINGS['top_n_words'],
                min_topic_size=config.BERTOPIC_SETTINGS['min_topic_size'],
                calculate_probabilities=config.BERTOPIC_SETTINGS['calculate_probabilities'],
                verbose=config.BERTOPIC_SETTINGS['verbose'],
                nr_topics=config.BERTOPIC_SETTINGS['nr_topics']
            )

            print(" Enhanced BERTopic model initialized successfully")

        except Exception as e:
            print(f" Enhanced initialization failed: {e}")
            self._initialize_fallback_model()

    def _initialize_fallback_model(self):
        """Initialize a default-configured BERTopic as a fallback."""
        print(" Initializing fallback model...")
        self.topic_model = BERTopic(
            min_topic_size=15,
            verbose=False,
            calculate_probabilities=True
        )

    @staticmethod
    def _encode_documents(documents, model_name):
        """Encode *documents* with the named SentenceTransformer (normalised embeddings)."""
        model = SentenceTransformer(model_name)
        return model.encode(
            documents,
            batch_size=config.BATCH_SIZE,
            show_progress_bar=True,
            normalize_embeddings=True
        )

    def fit_enhanced_model(self, documents):
        """Embed *documents*, fit BERTopic, optionally reduce topics, and report.

        Embeddings come from ``config.EMBEDDING_MODELS['primary']``, or from
        ``['backup']`` if the primary model fails. If fitting the configured
        model fails, a default BERTopic is fitted on the raw documents instead.

        Returns:
            ``(topics, probabilities)`` as stored on this instance.
        """
        print(" Fitting enhanced topic model...")

        # Generate embeddings with primary model
        try:
            self.embeddings = self._encode_documents(documents, config.EMBEDDING_MODELS['primary'])
        except Exception as e:
            print(f" Primary embedding failed, using backup: {e}")
            self.embeddings = self._encode_documents(documents, config.EMBEDDING_MODELS['backup'])

        self.initialize_enhanced_model()

        try:
            self.topics, self.probabilities = self.topic_model.fit_transform(
                documents, self.embeddings
            )
            print(" Enhanced model fitting completed successfully")
        except Exception as e:
            print(f" Enhanced fitting failed: {e}")
            self._fit_fallback_model(documents)

        self._optimize_topics(documents)
        self._print_enhanced_results()

        return self.topics, self.probabilities

    def _fit_fallback_model(self, documents):
        """Fit a default-configured BERTopic directly on the raw documents."""
        print(" Fitting fallback model...")
        self.topic_model = BERTopic(
            min_topic_size=20,
            verbose=False,
            calculate_probabilities=False
        )
        self.topics, self.probabilities = self.topic_model.fit_transform(documents)

    def _optimize_topics(self, documents):
        """Try to reduce the topic count when few topics were found.

        NOTE(review): BERTopic's ``reduce_topics`` only merges topics, and the
        target here is always >= 8 while the branch runs only with < 6 topics,
        so the model is left unchanged. Confirm the intended threshold.
        """
        unique_topics = self._count_topics(self.topics)
        print(f" Found {unique_topics} initial topics")

        if unique_topics < 6 and hasattr(self.topic_model, 'reduce_topics'):
            try:
                target_topics = max(8, min(25, len(documents) // 30))
                print(f" Optimizing to {target_topics} topics...")
                self._reduce_topics(documents, target_topics)
                print(f" Optimized to {self._count_topics(self.topics)} topics")
            except Exception as e:
                print(f" Topic optimization failed: {e}")

    def _reduce_topics(self, documents, nr_topics):
        """Call ``reduce_topics`` across BERTopic API versions and refresh state."""
        try:
            # Current API: reduce_topics(docs, nr_topics=...) updates the model
            # in place and exposes the result via topics_ / probabilities_.
            self.topic_model.reduce_topics(documents, nr_topics=nr_topics)
        except TypeError:
            # Legacy API required topics/probabilities and returned new ones.
            self.topics, self.probabilities = self.topic_model.reduce_topics(
                documents, self.topics, self.probabilities, nr_topics=nr_topics
            )
            return
        self.topics = self.topic_model.topics_
        self.probabilities = getattr(self.topic_model, 'probabilities_', self.probabilities)

    def _print_enhanced_results(self):
        """Print topic counts, outlier share, and topic-size statistics."""
        topic_array = np.asarray(self.topics)
        unique_topics = self._count_topics(self.topics)
        # Compare on an array: topics may be a plain list, and `list == -1`
        # is just False, which would make the outlier count always 0.
        outliers = int(np.sum(topic_array == -1))
        outlier_percentage = (outliers / len(topic_array)) * 100 if len(topic_array) else 0.0

        topic_counts = Counter(self.topics)
        valid_topics = {k: v for k, v in topic_counts.items() if k != -1}

        if valid_topics:
            sizes = list(valid_topics.values())
            stats = {
                'min_size': min(sizes),
                'max_size': max(sizes),
                'avg_size': np.mean(sizes),
                'std_size': np.std(sizes),
                'median_size': np.median(sizes)
            }
        else:
            stats = {'min_size': 0, 'max_size': 0, 'avg_size': 0, 'std_size': 0, 'median_size': 0}

        print(f"\n ENHANCED TOPIC MODELING RESULTS:")
        print(f"   • Topics discovered: {unique_topics}")
        print(f"   • Outliers: {outliers} ({outlier_percentage:.1f}%)")
        print(f"   • Topic size range: {stats['min_size']} - {stats['max_size']}")
        print(f"   • Average topic size: {stats['avg_size']:.1f}")
        print(f"   • Median topic size: {stats['median_size']:.1f}")

        self._assess_topic_quality(valid_topics)
        self._display_sample_topics()

    def _assess_topic_quality(self, valid_topics):
        """Print balance ratio, size coefficient of variation, and a verdict."""
        if not valid_topics:
            return

        sizes = list(valid_topics.values())
        balance_ratio = min(sizes) / max(sizes) if max(sizes) > 0 else 0
        cv = np.std(sizes) / np.mean(sizes) if np.mean(sizes) > 0 else 0

        print(f"   • Topic balance ratio: {balance_ratio:.3f}")
        print(f"   • Size coefficient of variation: {cv:.3f}")

        if balance_ratio > 0.15 and cv < 1.0:
            print("   •  Excellent topic distribution")
        elif balance_ratio > 0.05:
            print("   •  Good topic distribution")
        else:
            print("   •  Poor topic distribution - consider parameter adjustment")

    def _display_sample_topics(self):
        """Print up to 8 non-outlier topics with their top 6 words."""
        if hasattr(self.topic_model, 'get_topic_info'):
            try:
                topic_info = self.topic_model.get_topic_info()
                valid_topics = topic_info[topic_info['Topic'] != -1]

                print(f"\n TOPIC OVERVIEW (showing {min(8, len(valid_topics))} sample topics):")
                print("=" * 80)

                for _, row in valid_topics.head(8).iterrows():
                    topic_id = row['Topic']
                    topic_words = self.topic_model.get_topic(topic_id)
                    if topic_words:
                        words = [word for word, _ in topic_words[:6]]
                        size_percentage = (row['Count'] / len(self.topics)) * 100
                        print(f"   Topic {topic_id:2d} ({row['Count']:3d} docs, {size_percentage:4.1f}%): {', '.join(words)}")

            except Exception as e:
                print(f"   • Could not extract topic details: {e}")


class EnhancedDomainClassifier:
    """Enhanced domain classifier with improved keyword matching.

    Each domain is scored from (a) exact matches of a topic's leading words
    against the domain keyword list, weighted by rank, and (b) whole-word
    mentions of domain keywords in the topic text (topic words plus optional
    sample-document context). Near-ties are broken using strong indicators
    and a small context vote.
    """

    # keyword -> compiled whole-word regex, shared by all instances.
    _keyword_pattern_cache = {}

    def __init__(self):
        # Comprehensive domain keywords with FIXED coverage
        self.domain_keywords = {
            'biology': [
                'cell', 'gene', 'protein', 'dna', 'genetic', 'molecular', 'organism',
                'evolution', 'genome', 'species', 'ecological', 'biodiversity',
                'microbial', 'enzyme', 'metabolism', 'phylogenetic', 'transcription',
                'rna', 'chromosome', 'mitochondria', 'apoptosis', 'sequence',
                'mutation', 'expression', 'cellular', 'developmental', 'plant',
                'animal', 'bacterial', 'viral', 'evolutionary', 'population',
                'biological', 'physiology', 'genomic', 'proteomic', 'transcriptomic',
                'microbiome', 'neuroscience', 'zoology', 'botany', 'ecology',
                'soybean', 'fruit', 'circrnas', 'drosophila', 'nutrient', 'invasive',
                'breast', 'strain', 'speech', 'brain', 'neural', 'jet', 'physics',
                'collision', 'proton', 'oxygen', 'sex', 'temperature', 'thermal'
            ],
            'medicine': [
                'patient', 'clinical', 'treatment', 'disease', 'medical', 'therapy',
                'health', 'drug', 'vaccine', 'diagnosis', 'symptom', 'hospital',
                'pharmaceutical', 'epidemiology', 'pathology', 'oncology', 'immunology',
                'surgery', 'prognosis', 'biomarker', 'clinical trial', 'pharmacology',
                'therapeutic', 'dosage', 'recovery', 'mortality', 'morbidity', 'cancer',
                'tumor', 'infection', 'inflammatory', 'neurological', 'cardiology',
                'pediatric', 'geriatric', 'psychiatry', 'radiology', 'anesthesia',
                'public health', 'virology', 'bacteriology', 'dose', 'imaging',
                'muscle', 'care', 'inflammation', 'mental', 'virus', 'sars', 'vector',
                'mouse', 'dam', 'like', 'wash', 'londrina'
            ],
            'chemistry': [
                'molecule', 'reaction', 'compound', 'chemical', 'synthesis',
                'catalyst', 'polymer', 'organic', 'inorganic', 'spectroscopy',
                'chromatography', 'crystallography', 'stoichiometry', 'kinetics',
                'nmr', 'mass spectrometry', 'electrochemistry', 'photochemistry',
                'reagent', 'solvent', 'yield', 'purification', 'characterization',
                'crystal', 'bond', 'structure', 'atomic', 'molecular', 'analytical',
                'physical chemistry', 'quantum chemistry', 'medicinal chemistry',
                'biochemistry', 'polymer chemistry', 'materials chemistry',
                'irrigation', 'management', 'soil', 'water'
            ],
            'environmental_science': [
                'environmental', 'climate', 'ecosystem', 'sustainability', 'pollution',
                'conservation', 'biodiversity', 'ecological', 'environmental impact',
                'climate change', 'environmental management', 'sustainable development',
                'habitat', 'wildlife', 'conservation', 'environmental policy', 'earth',
                'geological', 'ocean', 'atmospheric', 'marine', 'terrestrial',
                'agricultural', 'forestry', 'water resources', 'air quality', 'soil science',
                'environmental engineering', 'conservation biology', 'environmental health',
                'ice', 'university', 'usa', 'fish', 'population', 'ecology',
                'geological survey', 'santa', 'survey', 'rupture'
            ],
            'computer_science': [
                'algorithm', 'software', 'programming', 'machine learning',
                'neural network', 'data', 'system', 'computational', 'artificial intelligence',
                'database', 'network', 'optimization', 'cybersecurity', 'blockchain',
                'deep learning', 'computer vision', 'natural language processing',
                'computation', 'modeling', 'simulation', 'data analysis', 'big data',
                'cloud computing', 'internet of things', 'robotics', 'automation'
            ],
            'physics': [
                'quantum', 'particle', 'energy', 'field', 'mechanics', 'astrophysics',
                'relativity', 'nuclear', 'optics', 'thermodynamics', 'electromagnetic',
                'condensed matter', 'cosmology', 'entanglement', 'superconductivity',
                'atomic', 'molecular', 'theoretical', 'experimental', 'quantum mechanics',
                'statistical mechanics', 'fluid dynamics', 'plasma physics', 'optics',
                'astronomy', 'cosmology', 'particle physics', 'solid state physics'
            ],
            'engineering': [
                'design', 'system', 'manufacturing', 'structural', 'electrical',
                'mechanical', 'control', 'sensor', 'robotics', 'automation',
                'aerospace', 'civil', 'materials', 'nanotechnology', 'mechatronics',
                'biomedical', 'chemical engineering', 'environmental engineering'
            ],
            'materials_science': [
                'material', 'composite', 'polymer', 'ceramic', 'metal', 'alloy',
                'nanomaterial', 'crystal', 'structure', 'properties', 'synthesis',
                'fabrication', 'characterization', 'mechanical properties', 'thermal properties',
                'electronic properties', 'optical properties', 'material design',
                'heat', 'strength', 'stimulus', 'group'
            ],
            'psychology': [
                'behavior', 'cognitive', 'psychological', 'mental', 'personality',
                'emotion', 'memory', 'neural', 'brain', 'perception', 'learning',
                'cognition', 'developmental', 'social psychology', 'clinical psychology',
                'behavioral', 'psychiatric', 'neuroscience', 'cognitive science',
                'speech', 'elife', 'neuroscience', 'university'
            ]
        }

        # Domain relationships for ambiguous cases
        self.domain_relationships = {
            'ecology': ['biology', 'environmental_science'],
            'neuroscience': ['biology', 'psychology', 'medicine'],
            'biochemistry': ['biology', 'chemistry'],
            'bioinformatics': ['biology', 'computer_science'],
            'materials_science': ['chemistry', 'engineering', 'physics'],
            'environmental_health': ['environmental_science', 'medicine']
        }

        # Strong domain indicators
        self.strong_indicators = {
            'biology': ['gene', 'cell', 'dna', 'protein', 'genome', 'evolution'],
            'medicine': ['patient', 'clinical', 'treatment', 'diagnosis', 'therapy', 'hospital'],
            'chemistry': ['molecule', 'reaction', 'compound', 'synthesis', 'catalyst'],
            'environmental_science': ['climate', 'ecosystem', 'pollution', 'conservation', 'habitat'],
            'psychology': ['behavior', 'cognitive', 'mental', 'brain', 'memory'],
            'physics': ['quantum', 'particle', 'energy', 'mechanics', 'relativity'],
            'materials_science': ['material', 'composite', 'ceramic', 'polymer', 'alloy']
        }

    @classmethod
    def _keyword_pattern(cls, keyword):
        """Return a compiled regex matching *keyword* as a whole word or phrase.

        A trailing ``s``/``es`` is accepted so simple plurals ("cells",
        "genes") still count. Occurrences inside unrelated words ("mental" in
        "environmental", "gene" in "general", "cell" in "excellent") do not.
        """
        pattern = cls._keyword_pattern_cache.get(keyword)
        if pattern is None:
            pattern = re.compile(r'\b' + re.escape(keyword) + r'(?:e?s)?\b')
            cls._keyword_pattern_cache[keyword] = pattern
        return pattern

    def _mentions(self, text, keyword):
        """Return True if *keyword* occurs in *text* as a whole word/phrase."""
        return self._keyword_pattern(keyword).search(text) is not None

    def classify_topic_domains(self, topic_model, topics, model_type="bertopic",
                              topic_samples=None):
        """Enhanced domain classification with sample document context.

        Args:
            topic_model: BERTopic model (``get_topic``) or LDA wrapper
                (``get_topic_words``), selected by *model_type*.
            topics: per-document topic ids; -1 is ignored for BERTopic.
            model_type: ``"bertopic"`` or any other value for the LDA path.
            topic_samples: optional mapping topic_id -> list of sample texts;
                the first 500 characters of their join are added as context.

        Returns:
            dict mapping ``int(topic_id)`` to a classification record. Topics
            that cannot be classified are omitted.
        """
        print(f" Performing enhanced domain classification for {model_type}...")

        topic_domain_mapping = {}
        unique_topics = set(topics) - {-1} if model_type == "bertopic" else set(topics)

        with tqdm(total=len(unique_topics), desc=f"Classifying {model_type} topics") as pbar:
            for topic_id in unique_topics:
                try:
                    result = self._classify_single_topic(
                        topic_model, topic_id, model_type, topic_samples
                    )
                    if result is not None:
                        topic_domain_mapping[int(topic_id)] = result

                        # Display classification
                        confidence = result['confidence']
                        confidence_level = "HIGH" if confidence > 0.7 else "MEDIUM" if confidence > 0.4 else "LOW"
                        top_keywords = ', '.join(result['topic_keywords'][:4])
                        print(f"   {model_type.upper()} Topic {topic_id:2d} → {result['primary_domain']:20s} "
                              f"(conf: {confidence:.3f} [{confidence_level}]) - {top_keywords}")

                except Exception as e:
                    print(f"    Error classifying topic {topic_id}: {str(e)}")

                pbar.update(1)

        self._print_enhanced_domain_analysis(topic_domain_mapping, model_type)
        return topic_domain_mapping

    def _get_topic_words(self, topic_model, topic_id, model_type):
        """Return the topic's words (up to 12 for BERTopic), or None if unavailable."""
        if model_type == "bertopic":
            topic_words_data = topic_model.get_topic(topic_id)
            if not topic_words_data:
                return None
            return [word for word, _ in topic_words_data[:12]]
        # LDA - use fixed method
        if hasattr(topic_model, 'get_topic_words'):
            return topic_model.get_topic_words(topic_id)
        return None

    def _classify_single_topic(self, topic_model, topic_id, model_type, topic_samples):
        """Build the classification record for one topic, or return None."""
        all_topic_words = self._get_topic_words(topic_model, topic_id, model_type)
        if not all_topic_words:
            return None

        # Create topic text with optional sample context
        topic_text = ' '.join(all_topic_words).lower()
        if topic_samples and topic_id in topic_samples:
            sample_text = ' '.join(topic_samples[topic_id])
            topic_text += ' ' + sample_text[:500].lower()  # Add sample context

        domain_scores = self._calculate_enhanced_domain_scores(all_topic_words, topic_text)
        if not domain_scores:
            return None

        best_domain = self._resolve_enhanced_domains(domain_scores, all_topic_words, topic_text)
        if best_domain == 'unknown':
            return None

        confidence = self._calculate_confidence_with_context(domain_scores, best_domain, topic_text)

        return {
            'primary_domain': best_domain,
            'confidence': float(confidence),
            'all_domain_scores': {k: float(v) for k, v in domain_scores.items()},
            'topic_keywords': all_topic_words[:8],
            'model_type': model_type,
            'classification_method': 'enhanced_keyword'
        }

    def _calculate_enhanced_domain_scores(self, topic_words, topic_text):
        """Score every domain for one topic.

        Each domain starts at 0.1. Exact matches among the first 10 topic words
        add ``1 + 2 * (10 - rank) / 10``. Keywords mentioned as whole words in
        *topic_text*, and not already matched exactly, add 0.8 for multi-word
        terms or 0.3 for single words. Multiplicative bonuses are then applied
        for 2+/3+ exact matches, a strong indicator in the first 4 words, and
        more than 5 distinct keywords mentioned in the text.
        """
        domain_scores = {}

        for domain, keywords in self.domain_keywords.items():
            keyword_set = set(keywords)
            # dict.fromkeys drops duplicate list entries so they are not double-counted.
            unique_keywords = list(dict.fromkeys(keywords))
            score = 0.1  # Base score

            # Position-based scoring
            match_count = 0
            matched_words = set()
            for i, word in enumerate(topic_words[:10]):
                if word in keyword_set:
                    position_weight = max(0, (10 - i) / 10) * 2.0
                    score += 1.0 + position_weight
                    match_count += 1
                    matched_words.add(word)

            # Contextual presence scoring (whole-word matches only)
            mentioned = [kw for kw in unique_keywords if self._mentions(topic_text, kw)]
            for keyword in mentioned:
                if keyword not in matched_words:
                    score += 0.8 if ' ' in keyword else 0.3  # multi-word vs single word

            # Multi-match bonus
            if match_count >= 3:
                score *= 1.3
            elif match_count >= 2:
                score *= 1.15

            # Strong domain indicator bonus
            strong = self.strong_indicators.get(domain, [])
            if any(word in strong for word in topic_words[:4]):
                score *= 1.2

            # Context density bonus (more keywords in context = higher score)
            if len(mentioned) > 5:
                score *= 1.1

            domain_scores[domain] = score

        return domain_scores

    def _resolve_enhanced_domains(self, domain_scores, topic_words, topic_text):
        """Pick the best domain, using indicators and context to break near-ties."""
        domain_scores = {k: v for k, v in domain_scores.items() if v > 0}

        if not domain_scores:
            return 'unknown'

        sorted_domains = sorted(domain_scores.items(), key=lambda x: x[1], reverse=True)

        if len(sorted_domains) < 2:
            return sorted_domains[0][0]

        best_domain, best_score = sorted_domains[0]
        second_domain, second_score = sorted_domains[1]

        # Clear winner (score difference > 50%)
        if best_score > second_score * 1.5:
            return best_domain

        # Check for domain relationships
        for related_domains in self.domain_relationships.values():
            if best_domain in related_domains and second_domain in related_domains:
                # Choose based on stronger indicators among the leading words
                best_indicators = sum(1 for word in topic_words[:4]
                                      if word in self.strong_indicators.get(best_domain, []))
                second_indicators = sum(1 for word in topic_words[:4]
                                        if word in self.strong_indicators.get(second_domain, []))

                if best_indicators > second_indicators * 1.5:
                    return best_domain
                elif second_indicators > best_indicators * 1.5:
                    return second_domain

        # Check context for additional clues
        context_winner = self._check_context_for_domain(topic_text)
        if context_winner and context_winner in [best_domain, second_domain]:
            return context_winner

        # Default to highest score
        return best_domain

    def _check_context_for_domain(self, topic_text):
        """Return the domain with the most whole-word context indicators, or None."""
        domain_context_indicators = {
            'biology': ['cell', 'organism', 'species', 'evolutionary'],
            'medicine': ['clinical', 'patient', 'treatment', 'diagnosis'],
            'chemistry': ['compound', 'reaction', 'synthesis', 'chemical'],
            'physics': ['quantum', 'particle', 'energy', 'field'],
            'computer_science': ['algorithm', 'data', 'network', 'computational'],
            'engineering': ['design', 'system', 'structural', 'mechanical']
        }

        scores = {}
        for domain, indicators in domain_context_indicators.items():
            score = sum(1 for indicator in indicators if self._mentions(topic_text, indicator))
            if score > 0:
                scores[domain] = score

        if scores:
            return max(scores.items(), key=lambda x: x[1])[0]
        return None

    def _calculate_confidence_with_context(self, domain_scores, best_domain, topic_text):
        """Return best_score / total_score, boosted for dominance and strong context, capped at 1."""
        if best_domain not in domain_scores:
            return 0.0

        best_score = domain_scores[best_domain]
        total_score = sum(domain_scores.values())

        if total_score == 0:
            return 0.0

        base_confidence = best_score / total_score

        # Adjust based on score dominance
        sorted_scores = sorted(domain_scores.values(), reverse=True)
        if len(sorted_scores) > 1:
            dominance_ratio = sorted_scores[0] / sorted_scores[1]
            if dominance_ratio > 2.0:
                base_confidence = min(1.0, base_confidence * 1.3)
            elif dominance_ratio > 1.5:
                base_confidence = min(1.0, base_confidence * 1.15)

        # Adjust based on strong indicators mentioned (whole-word) in context
        strong_indicator_count = sum(1 for indicator in self.strong_indicators.get(best_domain, [])
                                     if self._mentions(topic_text, indicator))
        if strong_indicator_count >= 2:
            base_confidence = min(1.0, base_confidence * 1.1)

        return min(1.0, base_confidence)

    def _print_enhanced_domain_analysis(self, domain_mapping, model_type):
        """Print per-domain topic counts, shares, and mean confidence."""
        if not domain_mapping:
            print(f" No topics to analyze for {model_type}")
            return

        domain_counts = Counter()
        confidence_sum = Counter()

        for mapping in domain_mapping.values():
            domain = mapping['primary_domain']
            domain_counts[domain] += 1
            confidence_sum[domain] += mapping['confidence']

        print(f"\n {model_type.upper()} ENHANCED DOMAIN DISTRIBUTION:")
        print("=" * 60)

        total_topics = len(domain_mapping)
        for domain, count in domain_counts.most_common():
            percentage = (count / total_topics) * 100
            avg_confidence = confidence_sum[domain] / count
            confidence_level = "HIGH" if avg_confidence > 0.7 else "MEDIUM" if avg_confidence > 0.5 else "LOW"
            print(f"   {domain:22s}: {count:2d} topics ({percentage:5.1f}%) "
                  f"[conf: {avg_confidence:.3f} - {confidence_level}]")


class EnhancedModelEvaluator:
    """Enhanced model evaluator with comprehensive metrics"""
    def __init__(self):
        self.results = {}

    def evaluate_enhanced_model(self, topic_model, documents, topics,
                               embeddings=None, model_type="bertopic",
                               domain_mapping=None):
        """Enhanced model evaluation with comprehensive metrics.

        Args:
            topic_model: fitted BERTopic model or LDA wrapper (see *model_type*).
            documents: the documents the model was fitted on.
            topics: per-document topic ids (list or array); for BERTopic, -1
                marks outliers.
            embeddings: forwarded to ``_calculate_additional_metrics``.
            model_type: ``"bertopic"`` or any other value for LDA.
            domain_mapping: optional output of the domain classifier.

        Returns:
            The evaluation dict, also stored in ``self.results``.
        """
        print(f" Performing enhanced {model_type} model evaluation...")

        evaluation = {}

        # Basic statistics
        if model_type == "bertopic":
            unique_topics = len(set(topics)) - (1 if -1 in topics else 0)
            # Compare on an array: topics may be a plain list, where
            # `topics == -1` is just False and would hide every outlier.
            outliers = int(np.sum(np.asarray(topics) == -1))
        else:  # LDA
            unique_topics = len(set(topics))
            outliers = 0

        evaluation['basic_stats'] = {
            'total_documents': int(len(documents)),
            'topics_found': int(unique_topics),
            'outliers': int(outliers),
            'outlier_percentage': float((outliers / len(topics)) * 100) if len(topics) > 0 else 0.0
        }

        # Enhanced coherence calculation
        coherence = self._calculate_enhanced_coherence(topic_model, documents, topics, model_type)
        evaluation['coherence_score'] = float(coherence)

        # Topic quality metrics
        evaluation['topic_quality'] = self._analyze_enhanced_topic_quality(topics, model_type)

        # Domain metrics if available
        if domain_mapping:
            evaluation['domain_metrics'] = self._calculate_domain_metrics(domain_mapping)

        # Additional metrics
        evaluation['additional_metrics'] = self._calculate_additional_metrics(topics, embeddings, model_type)

        # Overall score with enhanced weighting
        evaluation['overall_score'] = float(self._calculate_enhanced_overall_score(evaluation, model_type))
        evaluation['model_type'] = model_type

        self.results = evaluation
        self._print_enhanced_evaluation_results(model_type)

        return evaluation

    def _calculate_enhanced_coherence(self, topic_model, documents, topics, model_type):
        """Enhanced coherence calculation"""
        try:
            topic_words = []

            if model_type == "bertopic":
                for topic in set(topics):
                    if topic != -1:
                        words = topic_model.get_topic(topic)
                        if words and len(words) >= 3:
                            topic_words.append([word for word, _ in words[:8]])
            else:  # LDA
                if hasattr(topic_model, 'get_topic_words'):
                    for topic_id in set(topics):
                        words = topic_model.get_topic_words(topic_id)
                        if words and len(words) >= 3:
                            topic_words.append(words[:8])

            if len(topic_words) < 2:
                return 0.0

            # Tokenize documents
            tokenized_docs = [doc.split() for doc in documents if doc and len(doc.split()) > 10]

            if len(tokenized_docs) < 15:
                return 0.0

            # Calculate coherence
            dictionary = corpora.Dictionary(tokenized_docs)
            dictionary.filter_extremes(no_below=3, no_above=0.75)
            corpus = [dictionary.doc2bow(doc) for doc in tokenized_docs]

            if len(corpus) < 10:
                return 0.0

            coherence_model = CoherenceModel(
                topics=topic_words,
                texts=tokenized_docs,
                dictionary=dictionary,
                coherence='c_v',
                processes=1
            )

            coherence_score = coherence_model.get_coherence()
            return max(0.0, float(coherence_score))

        except Exception as e:
            print(f" Enhanced coherence calculation failed for {model_type}: {e}")
            return 0.0

    def _analyze_enhanced_topic_quality(self, topics, model_type):
        """Analyze enhanced topic quality metrics"""
        topic_counts = Counter(topics)

        if model_type == "bertopic" and -1 in topic_counts:
            valid_topics = {k: v for k, v in topic_counts.items() if k != -1}
        else:
            valid_topics = topic_counts

        if not valid_topics:
            return {
                'balance_score': 0.0,
                'size_variation': 0.0,
                'avg_topic_size': 0.0,
                'min_topic_size': 0.0,
                'max_topic_size': 0.0,
                'topic_diversity': 0.0,
                'topic_distribution_score': 0.0
            }

        sizes = list(valid_topics.values())

        # Calculate topic diversity
        total_docs = sum(sizes)
        if total_docs > 0:
            proportions = [size / total_docs for size in sizes]
            diversity = 1 - sum(p**2 for p in proportions)
        else:
            diversity = 0.0

        # Calculate distribution score (penalize extreme sizes)
        avg_size = np.mean(sizes)
        if avg_size > 0:
            size_variation = np.std(sizes) / avg_size
            distribution_score = 1 / (1 + size_variation)
        else:
            distribution_score = 0.0

        return {
            'balance_score': float(min(sizes) / max(sizes)) if max(sizes) > 0 else 0.0,
            'size_variation': float(np.std(sizes) / np.mean(sizes)) if np.mean(sizes) > 0 else 0.0,
            'avg_topic_size': float(np.mean(sizes)),
            'min_topic_size': float(min(sizes)),
            'max_topic_size': float(max(sizes)),
            'topic_diversity': float(diversity),
            'topic_distribution_score': float(distribution_score)
        }

    def _calculate_domain_metrics(self, domain_mapping):
        """Calculate domain-specific metrics"""
        if not domain_mapping:
            return {}

        domains = [mapping['primary_domain'] for mapping in domain_mapping.values()]
        confidences = [mapping['confidence'] for mapping in domain_mapping.values()]

        domain_counts = Counter(domains)

        return {
            'unique_domains': len(domain_counts),
            'domain_entropy': self._calculate_entropy(domains),
            'avg_confidence': float(np.mean(confidences)) if confidences else 0.0,
            'domain_distribution': {k: int(v) for k, v in domain_counts.items()}
        }

    def _calculate_entropy(self, items):
        """Calculate entropy of a distribution"""
        if not items:
            return 0.0

        counts = Counter(items)
        total = len(items)
        entropy = 0.0

        for count in counts.values():
            p = count / total
            entropy -= p * np.log2(p)

        return entropy / np.log2(len(counts)) if len(counts) > 1 else 0.0

    def _calculate_additional_metrics(self, topics, embeddings, model_type):
        """Calculate additional evaluation metrics"""
        try:
            if model_type == "bertopic":
                unique_topics = len(set(topics)) - (1 if -1 in topics else 0)
            else:
                unique_topics = len(set(topics))

            # Topic count score (penalize too few or too many topics)
            ideal_topics = 20  # Target number
            topic_count_score = 1 / (1 + abs(unique_topics - ideal_topics) / ideal_topics)

            return {
                'topic_count_score': float(topic_count_score),
                'unique_topics': int(unique_topics),
                'ideal_topics_deviation': float(abs(unique_topics - ideal_topics) / ideal_topics)
            }
        except:
            return {
                'topic_count_score': 0.0,
                'unique_topics': 0,
                'ideal_topics_deviation': 1.0
            }

    def _calculate_enhanced_overall_score(self, evaluation, model_type):
        """Calculate enhanced overall quality score"""
        basic = evaluation['basic_stats']
        coherence = evaluation['coherence_score']
        quality = evaluation['topic_quality']
        additional = evaluation['additional_metrics']

        if model_type == "bertopic":
            # BERTopic scoring
            score = (
                coherence * 0.35 +
                (1 - basic['outlier_percentage'] / 100) * 0.25 +
                quality['balance_score'] * 0.15 +
                quality['topic_distribution_score'] * 0.15 +
                additional['topic_count_score'] * 0.10
            )
        else:
            # LDA scoring
            score = (
                coherence * 0.40 +
                quality['balance_score'] * 0.20 +
                quality['topic_distribution_score'] * 0.20 +
                additional['topic_count_score'] * 0.20
            )

        # Add domain metrics if available
        if 'domain_metrics' in evaluation:
            domain = evaluation['domain_metrics']
            domain_score = (domain.get('avg_confidence', 0) +
                          (1 - domain.get('domain_entropy', 0))) / 2
            score = score * 0.8 + domain_score * 0.2

        return min(max(score, 0.0), 1.0)

    def _print_enhanced_evaluation_results(self, model_type):
        """Print enhanced evaluation results"""
        print(f"\n {model_type.upper()} ENHANCED MODEL EVALUATION:")
        print("=" * 60)

        basic = self.results['basic_stats']
        print(f"   • Documents: {basic['total_documents']}")
        print(f"   • Topics: {basic['topics_found']}")
        if model_type == "bertopic":
            print(f"   • Outliers: {basic['outliers']} ({basic['outlier_percentage']:.1f}%)")

        print(f"   • Coherence Score: {self.results['coherence_score']:.3f}")

        quality = self.results['topic_quality']
        print(f"   • Balance Score: {quality['balance_score']:.3f}")
        print(f"   • Topic Diversity: {quality['topic_diversity']:.3f}")
        print(f"   • Distribution Score: {quality['topic_distribution_score']:.3f}")
        print(f"   • Avg Topic Size: {quality['avg_topic_size']:.1f}")
        print(f"   • Size Range: {quality['min_topic_size']} - {quality['max_topic_size']}")

        additional = self.results['additional_metrics']
        print(f"   • Topic Count Score: {additional['topic_count_score']:.3f}")
        print(f"   • Topics from Ideal: {additional['ideal_topics_deviation']:.1%}")

        # Domain metrics if available
        if 'domain_metrics' in self.results:
            domain = self.results['domain_metrics']
            print(f"   • Unique Domains: {domain['unique_domains']}")
            print(f"   • Avg Domain Confidence: {domain.get('avg_confidence', 0):.3f}")
            print(f"   • Domain Entropy: {domain.get('domain_entropy', 0):.3f}")

        overall = self.results['overall_score']
        assessment = "EXCELLENT" if overall > 0.7 else "GOOD" if overall > 0.5 else "FAIR" if overall > 0.3 else "POOR"
        print(f"   • Overall Score: {overall:.3f} - {assessment}")


class EnhancedModelComparator:
    """Side-by-side comparison of a BERTopic model and an LDA model.

    Each per-model results dict is expected to carry ``'topics'`` (one topic
    id per document), ``'domain_mapping'`` (topic id -> dict with at least
    ``'primary_domain'`` and ``'confidence'``) and ``'evaluation'`` (with
    ``coherence_score``, ``overall_score`` and ``topic_quality``). The
    comparison is printed and written to ``config.COMPARISON_DIR``.
    """

    # Target topic count for the "Topics Found" row (closer is better).
    IDEAL_TOPICS = 20
    # Target domain count for the "Domains Covered" row (closer is better).
    IDEAL_DOMAINS = 6
    # Rows where the larger value simply wins.
    _HIGHER_IS_BETTER = frozenset({
        'Coherence Score', 'Overall Quality Score', 'Avg Confidence',
        'Topic Diversity', 'Distribution Score', 'Domain Purity',
    })

    def __init__(self):
        self.comparison_results = {}

    def compare_models_enhanced(self, bertopic_results, lda_results, documents):
        """Build, print and save the comparison of both models.

        Args:
            bertopic_results: BERTopic results dict (see class docstring).
            lda_results: LDA results dict.
            documents: the corpus, indexed like the ``topics`` sequences.

        Returns:
            dict with per-model summaries under ``'bertopic'`` / ``'lda'`` and
            a ``'comparison_timestamp'`` (ISO-8601, naive local time).
        """
        print("\n" + "="*70)
        print(" ENHANCED MODEL COMPARISON")
        print("="*70)

        # Sample documents per topic for context
        bertopic_samples = self._extract_topic_samples(
            bertopic_results['topics'], documents, bertopic_results['domain_mapping']
        )
        lda_samples = self._extract_topic_samples(
            lda_results['topics'], documents, lda_results['domain_mapping']
        )

        comparison = {
            'bertopic': self._enhance_model_results(
                bertopic_results, "bertopic", bertopic_samples
            ),
            'lda': self._enhance_model_results(
                lda_results, "lda", lda_samples
            ),
            'comparison_timestamp': datetime.now().isoformat()
        }

        self._print_enhanced_comparison(comparison)

        self.comparison_results = comparison
        self._save_enhanced_comparison()

        return comparison

    def _extract_topic_samples(self, topics, documents, domain_mapping):
        """Randomly sample up to ``config.TOPIC_EVALUATION_SAMPLE_SIZE`` docs per topic.

        Outliers (``-1``) and topics absent from *domain_mapping* are skipped.
        Uses the global NumPy RNG, so samples vary between runs unless it is seeded.
        """
        samples = {}

        for topic_id in set(topics):
            if topic_id == -1 or topic_id not in domain_mapping:
                continue

            doc_indices = [i for i, t in enumerate(topics) if t == topic_id]

            sample_size = min(config.TOPIC_EVALUATION_SAMPLE_SIZE, len(doc_indices))
            if sample_size > 0:
                sample_indices = np.random.choice(doc_indices, size=sample_size, replace=False)
                samples[topic_id] = [documents[i] for i in sample_indices]

        return samples

    def _enhance_model_results(self, model_results, model_type, topic_samples):
        """Summarise one model's results for the comparison table.

        ``avg_confidence`` is 0.0 for an empty domain mapping (``np.mean([])``
        would give nan plus a RuntimeWarning). ``domain_distribution`` and
        ``domain_purity`` (share of topics in the most common domain) are only
        added when the mapping is non-empty.
        """
        mapping = model_results['domain_mapping']
        confidences = [m['confidence'] for m in mapping.values()]

        enhanced = {
            'evaluation': model_results['evaluation'],
            'domain_mapping': mapping,
            'topics_found': len(mapping),
            'domains_covered': len({m['primary_domain'] for m in mapping.values()}),
            'avg_confidence': float(np.mean(confidences)) if confidences else 0.0,
            'model_type': model_type,
            'topic_samples_available': len(topic_samples)
        }

        if mapping:
            domain_counts = Counter(m['primary_domain'] for m in mapping.values())
            enhanced['domain_distribution'] = dict(domain_counts.most_common())
            # Purity: how concentrated the topics are in a single domain
            # (1.0 when every topic maps to the same domain).
            enhanced['domain_purity'] = max(domain_counts.values()) / sum(domain_counts.values())

        return enhanced

    @classmethod
    def _pick_winner(cls, metric_name, bert_value, lda_value):
        """Return ``"BERTopic"``, ``"LDA"``, ``"Tie"`` or ``"N/A"`` for one table row."""
        if not (isinstance(bert_value, (int, float)) and isinstance(lda_value, (int, float))):
            return "N/A"

        if metric_name in cls._HIGHER_IS_BETTER:
            bert_score, lda_score = bert_value, lda_value
        elif metric_name == 'Topics Found':
            # Smaller distance to the ideal count wins.
            bert_score = -abs(bert_value - cls.IDEAL_TOPICS)
            lda_score = -abs(lda_value - cls.IDEAL_TOPICS)
        elif metric_name == 'Domains Covered':
            # Neither too few nor too many domains.
            bert_score = 1 / (1 + abs(bert_value - cls.IDEAL_DOMAINS))
            lda_score = 1 / (1 + abs(lda_value - cls.IDEAL_DOMAINS))
        else:
            return "N/A"

        if bert_score > lda_score:
            return "BERTopic"
        if lda_score > bert_score:
            return "LDA"
        return "Tie"

    @staticmethod
    def _format_advantage(winner_score, loser_score):
        """Return the relative advantage as ``'+X.X%'``, or a note when undefined.

        A non-positive loser score would make the ratio divide by zero.
        """
        if loser_score <= 0:
            return "n/a (other model scored 0)"
        return f"+{(winner_score / loser_score - 1) * 100:.1f}%"

    def _print_enhanced_comparison(self, comparison):
        """Print the metric table, an overall verdict and the domain breakdown."""
        bert = comparison['bertopic']
        lda = comparison['lda']
        bert_eval = bert['evaluation']
        lda_eval = lda['evaluation']

        print("\n ENHANCED PERFORMANCE COMPARISON:")
        print("-" * 100)
        print(f"{'Metric':<30} {'BERTopic':<15} {'LDA':<15} {'Winner':<10} {'Notes':<30}")
        print("-" * 100)

        metrics = [
            ('Coherence Score',
             bert_eval['coherence_score'], lda_eval['coherence_score'],
             'Higher is better'),
            ('Overall Quality Score',
             bert_eval['overall_score'], lda_eval['overall_score'],
             'Model evaluation score'),
            ('Topics Found',
             bert['topics_found'], lda['topics_found'],
             'Closer to ideal (20) is better'),
            ('Domains Covered',
             bert['domains_covered'], lda['domains_covered'],
             'Domain diversity'),
            ('Avg Confidence',
             bert['avg_confidence'], lda['avg_confidence'],
             'Classification confidence'),
            ('Topic Diversity',
             bert_eval['topic_quality']['topic_diversity'],
             lda_eval['topic_quality']['topic_diversity'],
             'Higher = more balanced'),
            ('Distribution Score',
             bert_eval['topic_quality']['topic_distribution_score'],
             lda_eval['topic_quality']['topic_distribution_score'],
             'Topic size distribution'),
        ]

        # domain_purity only exists for a non-empty mapping, so either side may
        # lack it; reading it directly from both sides used to raise KeyError.
        if 'domain_purity' in bert or 'domain_purity' in lda:
            metrics.append(
                ('Domain Purity',
                 bert.get('domain_purity', 0.0), lda.get('domain_purity', 0.0),
                 'Higher = more concentrated domains')
            )

        for metric_name, bert_value, lda_value, notes in metrics:
            if isinstance(bert_value, float) or isinstance(lda_value, float):
                bert_display = f"{bert_value:.3f}"
                lda_display = f"{lda_value:.3f}"
            else:
                bert_display = str(bert_value)
                lda_display = str(lda_value)

            winner = self._pick_winner(metric_name, bert_value, lda_value)
            print(f"{metric_name:<30} {bert_display:<15} {lda_display:<15} {winner:<10} {notes:<30}")

        print("-" * 100)

        # Overall assessment: a 10% margin counts as "significantly better".
        bert_quality = bert_eval['overall_score']
        lda_quality = lda_eval['overall_score']

        if bert_quality > lda_quality * 1.1:
            print(f"\n OVERALL ASSESSMENT: BERTopic significantly better")
            print(f"   • BERTopic advantage: {self._format_advantage(bert_quality, lda_quality)}")
            print(f"   • Recommendation: Use BERTopic for this dataset")
        elif lda_quality > bert_quality * 1.1:
            print(f"\n OVERALL ASSESSMENT: LDA significantly better")
            print(f"   • LDA advantage: {self._format_advantage(lda_quality, bert_quality)}")
            print(f"   • Recommendation: Use LDA for this dataset")
        else:
            print(f"\n OVERALL ASSESSMENT: Comparable performance")
            print(f"   • BERTopic: {bert_quality:.3f}, LDA: {lda_quality:.3f}")
            print(f"   • Recommendation: Consider computational requirements")

        print(f"\n DOMAIN DISTRIBUTION COMPARISON:")
        bert_domains = bert.get('domain_distribution', {})
        lda_domains = lda.get('domain_distribution', {})

        all_domains = set(bert_domains) | set(lda_domains)
        if all_domains:
            print(f"{'Domain':<20} {'BERTopic':<10} {'LDA':<10}")
            print("-" * 45)
            for domain in sorted(all_domains):
                bert_count = bert_domains.get(domain, 0)
                lda_count = lda_domains.get(domain, 0)
                print(f"{domain:<20} {bert_count:<10} {lda_count:<10}")

    def _save_enhanced_comparison(self):
        """Write the comparison JSON and charts to ``config.COMPARISON_DIR``."""
        comparison_dir = config.COMPARISON_DIR
        # The directory may not exist yet; open() would raise FileNotFoundError.
        os.makedirs(comparison_dir, exist_ok=True)

        with open(f"{comparison_dir}/enhanced_comparison.json", 'w') as f:
            json.dump(self.comparison_results, f, indent=2, default=str)

        self._create_enhanced_visualizations()

        print(f" Enhanced comparison results saved to: {comparison_dir}")

    @classmethod
    def _normalized_scores(cls, model):
        """Return the seven chart values for one model summary.

        Order: coherence, overall, topics / IDEAL_TOPICS, min(domains / 8, 1),
        avg confidence, topic diversity, distribution score. The topic ratio
        is not capped and can exceed 1.
        """
        evaluation = model['evaluation']
        quality = evaluation['topic_quality']
        return [
            evaluation['coherence_score'],
            evaluation['overall_score'],
            model['topics_found'] / cls.IDEAL_TOPICS,
            min(model['domains_covered'] / 8, 1.0),
            model['avg_confidence'],
            quality['topic_diversity'],
            quality['topic_distribution_score'],
        ]

    def _create_enhanced_visualizations(self):
        """Write a radar chart (HTML) and a grouped bar chart (PNG) of both models."""
        bert = self.comparison_results['bertopic']
        lda = self.comparison_results['lda']

        metrics = ['Coherence', 'Overall', 'Topics*', 'Domains', 'Confidence', 'Diversity', 'Distribution']
        bert_scores = self._normalized_scores(bert)
        lda_scores = self._normalized_scores(lda)

        # Radar chart (values above 1, e.g. >20 topics, are clipped by the axis range)
        fig = go.Figure()

        fig.add_trace(go.Scatterpolar(
            r=bert_scores,
            theta=metrics,
            fill='toself',
            name='BERTopic',
            line_color='blue',
            opacity=0.8
        ))

        fig.add_trace(go.Scatterpolar(
            r=lda_scores,
            theta=metrics,
            fill='toself',
            name='LDA',
            line_color='red',
            opacity=0.8
        ))

        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    range=[0, 1]
                )),
            showlegend=True,
            title="Enhanced Model Comparison: BERTopic vs LDA",
            title_font_size=16,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            )
        )

        fig.write_html(f"{config.COMPARISON_DIR}/enhanced_comparison_radar.html")

        # Grouped bar chart; always release the figure, even if saving fails.
        fig, ax = plt.subplots(figsize=(14, 8))
        try:
            x = np.arange(len(metrics))
            width = 0.35

            bars1 = ax.bar(x - width/2, bert_scores, width, label='BERTopic',
                           color='blue', alpha=0.7, edgecolor='black')
            bars2 = ax.bar(x + width/2, lda_scores, width, label='LDA',
                           color='red', alpha=0.7, edgecolor='black')

            ax.set_xlabel('Evaluation Metrics', fontsize=12)
            ax.set_ylabel('Normalized Scores', fontsize=12)
            ax.set_title('Enhanced Model Comparison: BERTopic vs LDA', fontsize=14, fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels(metrics, rotation=45, ha='right')
            ax.legend()
            ax.grid(True, alpha=0.3)

            for bars in (bars1, bars2):
                for bar in bars:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                            f'{height:.2f}', ha='center', va='bottom', fontsize=9)

            plt.tight_layout()
            plt.savefig(f"{config.COMPARISON_DIR}/enhanced_comparison_bar.png",
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        print(" Enhanced visualizations created")


class EnhancedDomainModelingPipeline:
    """End-to-end topic/domain modelling pipeline (no LLM involved).

    Steps: load and clean the corpus, fit BERTopic and LDA, map topics to
    domains, evaluate and compare both models, then persist results under
    ``config.OUTPUT_DIR`` (BERTopic) and ``config.COMPARISON_DIR`` (LDA and
    comparison).
    """
    def __init__(self):
        self.data_processor = EnhancedDataProcessor()
        self.topic_modeler = EnhancedTopicModeler()
        self.fixed_lda = FixedLDAModel()
        self.domain_classifier = EnhancedDomainClassifier()
        self.evaluator = EnhancedModelEvaluator()
        self.comparator = EnhancedModelComparator()

    def run_enhanced_pipeline(self):
        """Run every pipeline step.

        Returns:
            dict with ``'bertopic'``, ``'lda'``, ``'comparison'`` and
            ``'documents'`` on success, or ``None`` if any step raised (the
            traceback is printed).
        """
        print(" STARTING ENHANCED DOMAIN MODELING PIPELINE")
        print("=" * 70)
        print(" Features:")
        print("   • Enhanced data processing")
        print("   • Fixed LDA implementation with proper topic extraction")
        print("   • Enhanced domain classification with context")
        print("   • Comprehensive model comparison")
        print("=" * 70)

        try:
            print("\n STEP 1: ENHANCED DATA PROCESSING")
            df = self.data_processor.load_and_process_data()
            documents = df['cleaned_text'].tolist()
            print(f" Processing {len(documents)} documents")

            print("\n STEP 2: ENHANCED BERTopic MODELING")
            bertopic_topics, _ = self.topic_modeler.fit_enhanced_model(documents)

            print("\n STEP 3: FIXED LDA MODELING")
            self.fixed_lda.fit(documents)
            lda_topics = self.fixed_lda.predict_topics(documents)
            # Return value unused here; the call is kept because it may set
            # model state read later (e.g. perplexity_score) — TODO confirm.
            self.fixed_lda.evaluate(documents)

            print("\n STEP 4: ENHANCED DOMAIN CLASSIFICATION")
            bertopic_samples = self._extract_topic_samples(bertopic_topics, documents)
            lda_samples = self._extract_topic_samples(lda_topics, documents)

            print("   Classifying BERTopic topics with context...")
            bertopic_domain_mapping = self.domain_classifier.classify_topic_domains(
                self.topic_modeler.topic_model, bertopic_topics, "bertopic", bertopic_samples
            )

            print("\n   Classifying LDA topics with context...")
            lda_domain_mapping = self.domain_classifier.classify_topic_domains(
                self.fixed_lda, lda_topics, "lda", lda_samples
            )

            print("\n STEP 5: ENHANCED MODEL EVALUATION")
            bertopic_evaluation = self.evaluator.evaluate_enhanced_model(
                self.topic_modeler.topic_model, documents, bertopic_topics,
                self.topic_modeler.embeddings, "bertopic", bertopic_domain_mapping
            )
            lda_evaluation_enhanced = self.evaluator.evaluate_enhanced_model(
                self.fixed_lda, documents, lda_topics, None, "lda", lda_domain_mapping
            )

            print("\n STEP 6: ENHANCED MODEL COMPARISON")
            bertopic_results = {
                'topics': bertopic_topics,
                'domain_mapping': bertopic_domain_mapping,
                'evaluation': bertopic_evaluation,
                'model': self.topic_modeler.topic_model
            }
            lda_results = {
                'topics': lda_topics,
                'domain_mapping': lda_domain_mapping,
                'evaluation': lda_evaluation_enhanced,
                'topic_model': self.fixed_lda
            }
            comparison_results = self.comparator.compare_models_enhanced(
                bertopic_results, lda_results, documents
            )

            print("\n STEP 7: SAVING ENHANCED RESULTS")
            self._save_enhanced_results(df, bertopic_results, lda_results, comparison_results)

            print("\n STEP 8: FINAL ENHANCED SUMMARY")
            self._print_enhanced_summary(bertopic_results, lda_results, comparison_results)

            print("\n ENHANCED PIPELINE COMPLETED SUCCESSFULLY!")

            return {
                'bertopic': bertopic_results,
                'lda': lda_results,
                'comparison': comparison_results,
                'documents': documents
            }

        except Exception as e:
            print(f" Enhanced pipeline execution failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _extract_topic_samples(self, topics, documents):
        """Randomly sample up to ``config.TOPIC_EVALUATION_SAMPLE_SIZE`` docs per topic.

        The outlier topic ``-1`` is skipped. Uses the global NumPy RNG.
        """
        samples = {}

        for topic_id in set(topics):
            if topic_id == -1:
                continue

            doc_indices = [i for i, t in enumerate(topics) if t == topic_id]

            sample_size = min(config.TOPIC_EVALUATION_SAMPLE_SIZE, len(doc_indices))
            if sample_size > 0:
                sample_indices = np.random.choice(doc_indices, size=sample_size, replace=False)
                samples[topic_id] = [documents[i] for i in sample_indices]

        return samples

    @staticmethod
    def _with_assignments(df, model_results):
        """Return a copy of *df* with topic, domain, confidence and method columns.

        Topics missing from the domain mapping (e.g. outliers) get NaN
        domain/confidence/method values.
        """
        out = df.copy()
        out['topic'] = model_results['topics']
        mapping = model_results['domain_mapping']
        out['domain'] = out['topic'].map(
            {k: v['primary_domain'] for k, v in mapping.items()}
        )
        out['confidence'] = out['topic'].map(
            {k: v['confidence'] for k, v in mapping.items()}
        )
        out['classification_method'] = out['topic'].map(
            {k: v.get('classification_method', 'enhanced_keyword') for k, v in mapping.items()}
        )
        return out

    def _save_enhanced_results(self, df, bertopic_results, lda_results, comparison_results):
        """Persist assignments, evaluations, the BERTopic model and LDA topic words.

        Best-effort: any exception is reported and swallowed so that a failed
        save does not abort the run.
        """
        try:
            # Output directories may not exist yet.
            os.makedirs(config.OUTPUT_DIR, exist_ok=True)
            os.makedirs(config.COMPARISON_DIR, exist_ok=True)

            bert_df = self._with_assignments(df, bertopic_results)
            bert_df.to_csv(f"{config.OUTPUT_DIR}/enhanced_bertopic_assignments.csv", index=False)

            lda_df = self._with_assignments(df, lda_results)
            lda_df.to_csv(f"{config.COMPARISON_DIR}/enhanced_lda_assignments.csv", index=False)

            with open(f"{config.OUTPUT_DIR}/enhanced_evaluations.json", 'w') as f:
                json.dump({
                    'bertopic': bertopic_results['evaluation'],
                    'lda': lda_results['evaluation'],
                    'comparison': comparison_results
                }, f, indent=2, default=str)

            self.topic_modeler.topic_model.save(f"{config.OUTPUT_DIR}/enhanced_topic_model")

            lda_topics_dict = self.fixed_lda.get_topics_dict(n_words=15)
            with open(f"{config.COMPARISON_DIR}/enhanced_lda_topics.json", 'w') as f:
                json.dump(lda_topics_dict, f, indent=2)

            print(f" Enhanced results saved successfully")
            print(f" BERTopic results: {config.OUTPUT_DIR}")
            print(f" LDA results: {config.COMPARISON_DIR}")

        except Exception as e:
            print(f" Error saving enhanced results: {e}")

    @staticmethod
    def _mapping_summary(domain_mapping):
        """Return ``(set of primary domains, mean confidence)`` for a mapping.

        The mean is 0.0 for an empty mapping rather than numpy's nan.
        """
        domains = {m['primary_domain'] for m in domain_mapping.values()}
        confidences = [m['confidence'] for m in domain_mapping.values()]
        avg_confidence = float(np.mean(confidences)) if confidences else 0.0
        return domains, avg_confidence

    @staticmethod
    def _quality_comparison_lines(bert_quality, lda_quality):
        """Return the summary lines comparing the two overall scores.

        Handles ties explicitly and avoids dividing by a zero score.
        """
        if bert_quality == lda_quality:
            return [
                f"   • BERTopic and LDA scored the same ({bert_quality:.3f})",
                "   • Recommendation: Consider computational requirements",
            ]
        if bert_quality > lda_quality:
            winner, loser, best, other = "BERTopic", "LDA", bert_quality, lda_quality
        else:
            winner, loser, best, other = "LDA", "BERTopic", lda_quality, bert_quality

        if other > 0:
            verdict = f"{(best / other - 1) * 100:.1f}% better than {loser}"
        else:
            verdict = f"better than {loser} (which scored 0)"
        return [
            f"   • {winner} is {verdict}",
            f"   • Recommendation: Use {winner} for this dataset",
        ]

    def _print_enhanced_summary(self, bertopic_results, lda_results, comparison_results):
        """Print per-model results and a head-to-head recommendation.

        ``comparison_results`` is accepted for interface compatibility but not used.
        """
        print("\n" + "="*70)
        print(" ENHANCED PIPELINE SUMMARY")
        print("="*70)

        bert_domains, bert_avg_conf = self._mapping_summary(bertopic_results['domain_mapping'])
        print(f"\n BERTopic Enhanced Results:")
        print(f"   • Topics classified: {len(bertopic_results['domain_mapping'])}")
        print(f"   • Domains identified: {len(bert_domains)}")
        print(f"   • Domains: {', '.join(sorted(bert_domains))}")
        print(f"   • Average confidence: {bert_avg_conf:.3f}")
        print(f"   • Coherence score: {bertopic_results['evaluation']['coherence_score']:.3f}")
        print(f"   • Overall quality: {bertopic_results['evaluation']['overall_score']:.3f}")

        lda_domains, lda_avg_conf = self._mapping_summary(lda_results['domain_mapping'])
        print(f"\n LDA Enhanced Results:")
        print(f"   • Topics classified: {len(lda_results['domain_mapping'])}")
        print(f"   • Domains identified: {len(lda_domains)}")
        print(f"   • Domains: {', '.join(sorted(lda_domains))}")
        print(f"   • Average confidence: {lda_avg_conf:.3f}")
        print(f"   • Coherence score: {lda_results['evaluation']['coherence_score']:.3f}")
        perplexity = getattr(self.fixed_lda, 'perplexity_score', None)
        if perplexity is not None:
            print(f"   • Perplexity: {perplexity:.2f}")
        else:
            print("   • Perplexity: n/a")
        print(f"   • Overall quality: {lda_results['evaluation']['overall_score']:.3f}")

        print(f"\n PERFORMANCE COMPARISON:")
        for line in self._quality_comparison_lines(
            bertopic_results['evaluation']['overall_score'],
            lda_results['evaluation']['overall_score'],
        ):
            print(line)

        print(f"\n Next steps:")
        print(f"   1. Check {config.OUTPUT_DIR}/ for detailed BERTopic results")
        print(f"   2. Check {config.COMPARISON_DIR}/ for LDA results and comparison")
        print(f"   3. Review the enhanced comparison report for insights")


# =============================================================================
# EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Keep third-party libraries quiet; the pipeline prints its own progress.
    for _noisy_logger in ("bertopic", "umap", "hdbscan"):
        logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

    # NOTE(review): the config class shown at the top of the notebook defines
    # only a single EMBEDDING_MODEL; fall back to it so the banner cannot abort
    # the run with AttributeError when EMBEDDING_MODELS is absent.
    _embedding_models = getattr(config, "EMBEDDING_MODELS", None)
    if _embedding_models is not None and hasattr(_embedding_models, "values"):
        _embedding_desc = list(_embedding_models.values())
    else:
        _embedding_desc = [getattr(config, "EMBEDDING_MODEL", "unknown")]

    print("\n" + "="*70)
    print(" ENHANCED DOMAIN MODELING PIPELINE WITHOUT LLM")
    print("="*70)
    print(f"   • Embedding Models: {_embedding_desc}")
    print(f"   • Enhanced LDA with proper topic extraction")
    print(f"   • Enhanced domain classification with context")
    print(f"   • Comprehensive comparison framework")
    print("="*70)

    pipeline = EnhancedDomainModelingPipeline()
    results = pipeline.run_enhanced_pipeline()

    if results is not None:
        print(f"\n" + "="*70)
        print(" ENHANCED PIPELINE EXECUTION COMPLETE")
        print("="*70)
        print(f" Check the following directories for results:")
        print(f"   1. {config.OUTPUT_DIR}/ - BERTopic results and visualizations")
        print(f"   2. {config.COMPARISON_DIR}/ - LDA results and enhanced comparison")
        print(f"   3. Review the enhanced comparison report for model selection guidance")
    else:
        print("\n Enhanced pipeline execution failed")

In [None]:


class EnhancedVisualizationGenerator:
    """Enhanced visualization generator for topic modeling results.

    Writes static PNG figures (matplotlib/seaborn), interactive HTML figures
    (plotly) and per-domain word clouds for one fitted topic model, plus a
    BERTopic-vs-LDA comparison dashboard.

    ``results_df`` arguments are per-document assignment tables. The methods
    read a ``topic`` column and a ``domain`` column, and optionally
    ``confidence``, ``date`` and ``x``/``y`` columns when they are present.
    ``domain_mapping`` arguments map a topic id to a dict with a
    ``'primary_domain'`` key and an optional ``'topic_keywords'`` list.
    """

    def __init__(self):
        # Qualitative colormap, sampled with values in [0, 1) for per-topic bars.
        self.color_palette = plt.cm.Set3
        # Fixed color per domain label; labels not listed here fall back to grey.
        self.domain_colors = {
            'biology': '#2E86AB',
            'medicine': '#A23B72',
            'chemistry': '#F18F01',
            'environmental_science': '#73AB84',
            'computer_science': '#6D597A',
            'physics': '#EF476F',
            'engineering': '#118AB2',
            'materials_science': '#06D6A0',
            'psychology': '#FFD166',
            'unknown': '#999999'
        }

    # ------------------------------------------------------------------
    # Pure helpers (no plotting)
    # ------------------------------------------------------------------

    @staticmethod
    def _group_topics_by_domain(domain_mapping):
        """Group ``domain_mapping`` entries by their primary domain.

        Args:
            domain_mapping: dict mapping topic id -> dict with a
                ``'primary_domain'`` key and an optional ``'topic_keywords'``
                sequence. ``None`` is treated as empty.

        Returns:
            dict mapping domain -> list of ``(topic_id, keywords)`` tuples in
            the iteration order of ``domain_mapping``; missing keyword lists
            become ``[]``.
        """
        grouped = {}
        for topic_id, mapping in (domain_mapping or {}).items():
            keywords = mapping.get('topic_keywords')
            keywords = [] if keywords is None else list(keywords)
            grouped.setdefault(mapping['primary_domain'], []).append((topic_id, keywords))
        return grouped

    @staticmethod
    def _safe_filename(name):
        """Return ``name`` as a lowercase filename fragment.

        Runs of characters other than letters, digits, ``_`` and ``-`` (e.g.
        spaces or ``/``) are replaced by a single ``_`` so the result can never
        escape the target directory.
        """
        return re.sub(r'[^\w\-]+', '_', str(name).strip()).lower()

    @staticmethod
    def _cosine_similarity_matrix(vectors):
        """Return the pairwise cosine-similarity matrix of the rows of ``vectors``.

        Args:
            vectors: 2-D array-like, one row per item.

        Returns:
            Square ``numpy.ndarray``. All-zero rows get similarity 0 with every
            row (including themselves) instead of NaN.
        """
        matrix = np.asarray(vectors, dtype=float)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid 0/0 for all-zero rows
        unit_rows = matrix / norms
        return unit_rows @ unit_rows.T

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def generate_all_visualizations(self, results_df, topic_model, domain_mapping,
                                    model_type="bertopic", output_dir=None):
        """Generate all visualizations for a model.

        Args:
            results_df: per-document assignment table (see class docstring).
            topic_model: fitted model; BERTopic-only plots are drawn only when
                ``model_type == "bertopic"``.
            domain_mapping: dict topic id -> {'primary_domain', 'topic_keywords', ...}.
            model_type: ``"bertopic"`` or another label (treated LDA-style).
            output_dir: base output directory. Defaults to ``config.OUTPUT_DIR``,
                looked up at call time so later changes to ``config`` apply.
        """
        if output_dir is None:
            output_dir = config.OUTPUT_DIR
        print(f"\n Generating enhanced visualizations for {model_type}...")

        # Create visualization directory
        vis_dir = f"{output_dir}/visualizations/{model_type}"
        os.makedirs(vis_dir, exist_ok=True)

        # 1. Topic Distribution
        self._plot_topic_distribution(results_df, model_type, vis_dir)

        # 2. Domain Distribution
        self._plot_domain_distribution(results_df, model_type, vis_dir)

        # 3. Topic Words Heatmap
        if model_type == "bertopic":
            self._plot_bertopic_heatmap(topic_model, results_df, vis_dir)
        else:
            self._plot_lda_topic_words(domain_mapping, vis_dir)

        # 4. Topic Similarity Matrix
        if model_type == "bertopic":
            self._plot_topic_similarity(topic_model, vis_dir)

        # 5. Document Embedding Visualization
        if hasattr(topic_model, 'embeddings_') and model_type == "bertopic":
            self._plot_document_embeddings(topic_model, results_df, vis_dir)

        # 6. Topic Evolution/Over Time
        if 'date' in results_df.columns:
            self._plot_topic_evolution(results_df, vis_dir)

        # 7. Interactive Visualizations
        self._create_interactive_visualizations(results_df, domain_mapping, model_type, vis_dir)

        # 8. Word Clouds
        self._create_wordclouds(domain_mapping, model_type, vis_dir)

        # 9. Confidence Distribution
        self._plot_confidence_distribution(results_df, model_type, vis_dir)

        print(f" Visualizations saved to: {vis_dir}")

    # ------------------------------------------------------------------
    # Static (matplotlib / seaborn) plots
    # ------------------------------------------------------------------

    def _plot_topic_distribution(self, results_df, model_type, vis_dir):
        """Plot per-topic document counts and a histogram of topic sizes.

        For BERTopic, the outlier topic ``-1`` is excluded from both panels and
        its document count is reported in the left panel's title instead of
        being drawn on the same axes.
        """
        topic_counts = results_df['topic'].value_counts()

        outliers = 0
        if model_type == "bertopic" and -1 in topic_counts.index:
            outliers = int(topic_counts[-1])
            topic_counts = topic_counts.drop(-1)

        if topic_counts.empty:
            print(f" No non-outlier topics to plot for {model_type}")
            return

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Plot topic sizes, smallest at the bottom
        topics_sorted = topic_counts.sort_values(ascending=True)
        n_topics = len(topics_sorted)
        colors = [self.color_palette(i / n_topics) for i in range(n_topics)]
        y_pos = np.arange(n_topics)

        axes[0].barh(y_pos, topics_sorted.values, color=colors, edgecolor='black')
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels([f"Topic {int(idx)}" for idx in topics_sorted.index])
        axes[0].set_xlabel('Number of Documents')
        title = f'{model_type.upper()} Topic Distribution'
        if outliers:
            title += f' ({outliers} outlier documents excluded)'
        axes[0].set_title(title)
        axes[0].grid(True, alpha=0.3, axis='x')

        # Add document counts just past the end of each bar
        label_offset = topics_sorted.max() * 0.01
        for i, v in enumerate(topics_sorted.values):
            axes[0].text(v + label_offset, i, str(v), va='center', fontsize=9)

        # Topic size distribution histogram
        sizes = topics_sorted.values
        axes[1].hist(sizes, bins=min(20, len(sizes)), color='skyblue',
                     edgecolor='black', alpha=0.7)
        axes[1].axvline(np.mean(sizes), color='red', linestyle='--',
                        label=f'Mean: {np.mean(sizes):.1f}')
        axes[1].axvline(np.median(sizes), color='green', linestyle='--',
                        label=f'Median: {np.median(sizes):.1f}')
        axes[1].set_xlabel('Topic Size (Number of Documents)')
        axes[1].set_ylabel('Frequency')
        axes[1].set_title('Topic Size Distribution')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{vis_dir}/topic_distribution.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

    def _plot_domain_distribution(self, results_df, model_type, vis_dir):
        """Plot domain counts as a pie and a bar chart, then the domain-topic matrix."""
        domain_counts = results_df['domain'].value_counts()
        if domain_counts.empty:
            return

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Prepare colors (unknown domains fall back to grey)
        domain_colors = [self.domain_colors.get(domain, '#999999')
                         for domain in domain_counts.index]

        # Pie chart
        axes[0].pie(domain_counts.values, labels=domain_counts.index,
                    colors=domain_colors, autopct='%1.1f%%', startangle=90)
        axes[0].axis('equal')
        axes[0].set_title(f'{model_type.upper()} Domain Distribution (Pie)')

        # Bar chart
        y_pos = np.arange(len(domain_counts))
        axes[1].barh(y_pos, domain_counts.values, color=domain_colors, edgecolor='black')
        axes[1].set_yticks(y_pos)
        axes[1].set_yticklabels(domain_counts.index)
        axes[1].set_xlabel('Number of Documents')
        axes[1].set_title(f'{model_type.upper()} Domain Distribution (Bar)')
        axes[1].grid(True, alpha=0.3, axis='x')

        # Add counts on bars
        label_offset = domain_counts.max() * 0.01
        for i, v in enumerate(domain_counts.values):
            axes[1].text(v + label_offset, i, str(v), va='center', fontsize=9)

        plt.tight_layout()
        plt.savefig(f'{vis_dir}/domain_distribution.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Create detailed domain-topic matrix
        if 'topic' in results_df.columns:
            self._plot_domain_topic_matrix(results_df, model_type, vis_dir)

    def _plot_domain_topic_matrix(self, results_df, model_type, vis_dir):
        """Plot an annotated heatmap of document counts per (domain, topic)."""
        cross_tab = pd.crosstab(results_df['domain'], results_df['topic'])

        # Remove the outlier topic if present
        if -1 in cross_tab.columns:
            cross_tab = cross_tab.drop(columns=[-1])
        if cross_tab.empty:
            return

        plt.figure(figsize=(max(10, cross_tab.shape[1] * 0.8),
                            max(6, cross_tab.shape[0] * 0.6)))

        sns.heatmap(cross_tab, annot=True, fmt='d', cmap='YlOrRd',
                    cbar_kws={'label': 'Number of Documents'})
        plt.title(f'{model_type.upper()} Domain-Topic Matrix')
        plt.xlabel('Topic ID')
        plt.ylabel('Domain')
        plt.tight_layout()
        plt.savefig(f'{vis_dir}/domain_topic_matrix.png', dpi=300, bbox_inches='tight')
        plt.close()

    def _plot_bertopic_heatmap(self, topic_model, results_df, vis_dir):
        """Draw a grid with each BERTopic topic's top 3 words on the diagonal.

        ``results_df`` is accepted for interface symmetry and is not used.
        Skipped when fewer than two non-outlier topics exist.
        """
        try:
            topic_info = topic_model.get_topic_info()
            valid_topics = topic_info[topic_info['Topic'] != -1]

            if len(valid_topics) < 2:
                return

            topics = valid_topics['Topic'].tolist()
            topic_words = {}
            for topic in topics:
                words = topic_model.get_topic(topic)
                if words:
                    topic_words[topic] = [word for word, _ in words[:5]]

            fig, ax = plt.subplots(figsize=(max(10, len(topics)), max(8, len(topics))))

            # One grid cell per topic; y inverted so topic 0 is at the top
            ax.set_xlim(0, len(topics))
            ax.set_ylim(0, len(topics))
            ax.invert_yaxis()

            for i, topic in enumerate(topics):
                # Row/column labels sit just outside the grid
                ax.text(i + 0.5, -0.5, f"Topic {topic}",
                        ha='center', va='top', rotation=45, fontsize=9)
                ax.text(-0.5, i + 0.5, f"Topic {topic}",
                        ha='right', va='center', fontsize=9)

                # Topic words on the diagonal
                if topic in topic_words:
                    words_str = "\n".join(topic_words[topic][:3])
                    ax.text(i + 0.5, i + 0.5, words_str,
                            ha='center', va='center', fontsize=8,
                            bbox=dict(boxstyle="round,pad=0.3",
                                      facecolor='lightblue', alpha=0.7))

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title('BERTopic Topics with Top Words')
            ax.grid(False)

            plt.tight_layout()
            plt.savefig(f'{vis_dir}/topic_words_heatmap.png', dpi=300, bbox_inches='tight')
            plt.close(fig)

        except Exception as e:
            print(f" Could not create BERTopic heatmap: {e}")

    def _plot_lda_topic_words(self, domain_mapping, vis_dir):
        """Plot one panel per domain listing its topics and their top 4 keywords."""
        domain_topics = self._group_topics_by_domain(domain_mapping)
        if not domain_topics:
            return

        # squeeze=False always yields a 2-D array, so a single domain works too
        fig, axes = plt.subplots(len(domain_topics), 1,
                                 figsize=(12, 4 * len(domain_topics)),
                                 squeeze=False)

        for ax, (domain, topics) in zip(axes[:, 0], domain_topics.items()):
            topic_ids = [topic_id for topic_id, _ in topics]
            topic_words = [', '.join(str(word) for word in keywords[:4])
                           for _, keywords in topics]

            # Unit-width bars act as row markers; keywords are written beside them
            y_pos = np.arange(len(topic_ids))
            colors = [self.domain_colors.get(domain, '#999999')] * len(topic_ids)

            bars = ax.barh(y_pos, [1] * len(topic_ids), color=colors, edgecolor='black')
            ax.set_yticks(y_pos)
            ax.set_yticklabels([f"Topic {tid}" for tid in topic_ids])
            ax.set_xlim(0, 1.2)
            ax.set_xlabel('')
            ax.set_title(f'Domain: {domain} - Topics {len(topic_ids)}')
            ax.grid(True, alpha=0.3, axis='x')

            for bar, words in zip(bars, topic_words):
                ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2,
                        words, va='center', fontsize=9)

        plt.tight_layout()
        plt.savefig(f'{vis_dir}/lda_topic_words_by_domain.png',
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    def _plot_topic_similarity(self, topic_model, vis_dir):
        """Plot a topic-topic similarity heatmap.

        Uses ``topic_model.calculate_topic_similarity()`` when the model
        provides it. Otherwise it falls back to the cosine similarity of
        ``topic_model.topic_embeddings_``; BERTopic exposes that attribute but
        has no ``calculate_topic_similarity`` method.
        """
        try:
            if hasattr(topic_model, 'calculate_topic_similarity'):
                similarity_matrix = topic_model.calculate_topic_similarity()
            else:
                topic_embeddings = getattr(topic_model, 'topic_embeddings_', None)
                if topic_embeddings is None:
                    print(" Could not create similarity matrix: "
                          "topic model exposes no topic embeddings")
                    return
                # NOTE(review): row order follows the model's topic_embeddings_
                # (BERTopic puts the outlier topic first when present) — confirm
                # before labelling the axes with topic ids.
                similarity_matrix = self._cosine_similarity_matrix(topic_embeddings)

            plt.figure(figsize=(10, 8))
            sns.heatmap(similarity_matrix, cmap='coolwarm', center=0,
                        square=True, cbar_kws={'label': 'Similarity'})
            plt.title('Topic Similarity Matrix')
            plt.xlabel('Topic ID')
            plt.ylabel('Topic ID')
            plt.tight_layout()
            plt.savefig(f'{vis_dir}/topic_similarity_matrix.png',
                        dpi=300, bbox_inches='tight')
            plt.close()

        except Exception as e:
            print(f" Could not create similarity matrix: {e}")

    def _plot_document_embeddings(self, topic_model, results_df, vis_dir):
        """Project ``topic_model.embeddings_`` to 2-D with UMAP and scatter-plot them.

        Points are colored by domain when a ``domain`` column exists, otherwise
        by topic. Skipped when the embeddings are missing or their count does
        not match the number of rows in ``results_df``.
        """
        try:
            embeddings = getattr(topic_model, 'embeddings_', None)
            if embeddings is None:
                return
            if len(embeddings) != len(results_df):
                print(" Could not create embeddings plot: "
                      f"{len(embeddings)} embeddings vs {len(results_df)} documents")
                return

            umap_model = UMAP(n_components=2, random_state=42)
            embeddings_2d = umap_model.fit_transform(embeddings)

            plt.figure(figsize=(12, 10))

            if 'domain' in results_df.columns:
                for domain in results_df['domain'].unique():
                    # Positional boolean mask, independent of the DataFrame index
                    mask = (results_df['domain'] == domain).to_numpy()
                    plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                                color=self.domain_colors.get(domain, '#999999'),
                                label=domain, alpha=0.6, s=20)

                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                            c=results_df['topic'].to_numpy(), cmap='tab20',
                            alpha=0.6, s=20)
                plt.colorbar(label='Topic ID')

            plt.title('Document Embeddings (2D UMAP projection)')
            plt.xlabel('UMAP 1')
            plt.ylabel('UMAP 2')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f'{vis_dir}/document_embeddings.png',
                        dpi=300, bbox_inches='tight')
            plt.close()

        except Exception as e:
            print(f" Could not create embeddings plot: {e}")

    def _plot_topic_evolution(self, results_df, vis_dir):
        """Plot a stacked area chart of documents per topic per calendar month.

        Works on a copy, so the caller's ``results_df`` is not modified.
        Rows with unparseable dates are dropped, and the outlier topic ``-1``
        is excluded.
        """
        try:
            if 'date' not in results_df.columns:
                return

            timeline = results_df.copy()
            timeline['date'] = pd.to_datetime(timeline['date'], errors='coerce')
            timeline = timeline.dropna(subset=['date'])
            if timeline.empty:
                return

            timeline['year_month'] = timeline['date'].dt.to_period('M')

            # Document counts per (month, topic); every row counts as one document
            pivot = pd.crosstab(timeline['year_month'], timeline['topic'])

            if -1 in pivot.columns:
                pivot = pivot.drop(columns=[-1])
            if pivot.empty:
                return

            plt.figure(figsize=(14, 8))
            pivot.plot(kind='area', alpha=0.7, stacked=True, colormap='tab20', ax=plt.gca())
            plt.title('Topic Evolution Over Time')
            plt.xlabel('Time')
            plt.ylabel('Number of Documents')
            plt.legend(title='Topic ID', bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f'{vis_dir}/topic_evolution.png',
                        dpi=300, bbox_inches='tight')
            plt.close()

        except Exception as e:
            print(f" Could not create topic evolution plot: {e}")

    # ------------------------------------------------------------------
    # Interactive plots and word clouds
    # ------------------------------------------------------------------

    def _create_interactive_visualizations(self, results_df, domain_mapping, model_type, vis_dir):
        """Write interactive Plotly HTML charts for topics, domains and (optionally) clusters.

        ``domain_mapping`` is accepted for interface symmetry and is not used.
        """
        # 1. Interactive topic-distribution bar chart
        topic_counts = results_df['topic'].value_counts().reset_index()
        topic_counts.columns = ['topic', 'count']

        fig = px.bar(topic_counts, x='topic', y='count',
                     title=f'{model_type.upper()} Topic Distribution',
                     labels={'topic': 'Topic ID', 'count': 'Number of Documents'},
                     color='count', color_continuous_scale='viridis')

        fig.write_html(f'{vis_dir}/interactive_topic_distribution.html')

        # 2. Interactive domain distribution
        if 'domain' in results_df.columns:
            domain_counts = results_df['domain'].value_counts().reset_index()
            domain_counts.columns = ['domain', 'count']

            # color='domain' is required for color_discrete_map to take effect
            fig = px.pie(domain_counts, values='count', names='domain',
                         color='domain',
                         title=f'{model_type.upper()} Domain Distribution',
                         color_discrete_map=self.domain_colors)

            fig.write_html(f'{vis_dir}/interactive_domain_distribution.html')

            # 3. Interactive scatter plot if 2-D coordinates are available
            if 'x' in results_df.columns and 'y' in results_df.columns:
                hover_columns = [col for col in ('topic', 'confidence')
                                 if col in results_df.columns]
                fig = px.scatter(results_df, x='x', y='y', color='domain',
                                 hover_data=hover_columns,
                                 title=f'{model_type.upper()} Document Clusters',
                                 color_discrete_map=self.domain_colors)

                fig.write_html(f'{vis_dir}/interactive_document_clusters.html')

    def _create_wordclouds(self, domain_mapping, model_type, vis_dir):
        """Write one word-cloud PNG per domain.

        A keyword's frequency is the number of that domain's topics listing it.
        """
        try:
            for domain, topics in self._group_topics_by_domain(domain_mapping).items():
                keywords = [kw for _, topic_keywords in topics for kw in topic_keywords]
                if not keywords:
                    continue

                wordcloud = WordCloud(
                    width=800,
                    height=400,
                    background_color='white',
                    colormap='tab20c',
                    max_words=50
                ).generate_from_frequencies(Counter(keywords))

                plt.figure(figsize=(12, 6))
                plt.imshow(wordcloud, interpolation='bilinear')
                plt.title(f'{model_type.upper()} - Domain: {domain}', fontsize=16)
                plt.axis('off')
                plt.tight_layout()

                plt.savefig(f'{vis_dir}/wordcloud_{self._safe_filename(domain)}.png',
                            dpi=300, bbox_inches='tight')
                plt.close()

        except Exception as e:
            print(f" Could not create word clouds: {e}")

    def _plot_confidence_distribution(self, results_df, model_type, vis_dir):
        """Plot a histogram of confidence scores and a per-domain box plot.

        NaN confidences are ignored. Skipped when there is no ``confidence``
        column or it contains only NaN.
        """
        if 'confidence' not in results_df.columns:
            return
        confidence = results_df['confidence'].dropna()
        if confidence.empty:
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        mean_conf = confidence.mean()
        median_conf = confidence.median()
        axes[0].hist(confidence, bins=20,
                     color='lightcoral', edgecolor='black', alpha=0.7)
        axes[0].axvline(mean_conf, color='red',
                        linestyle='--', label=f'Mean: {mean_conf:.3f}')
        axes[0].axvline(median_conf, color='blue',
                        linestyle='--', label=f'Median: {median_conf:.3f}')
        axes[0].set_xlabel('Confidence Score')
        axes[0].set_ylabel('Frequency')
        axes[0].set_title(f'{model_type.upper()} Confidence Distribution')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        if 'domain' in results_df.columns:
            data_to_plot = []
            domains = []
            for domain in results_df['domain'].unique():
                domain_data = results_df.loc[results_df['domain'] == domain, 'confidence'].dropna()
                if len(domain_data) > 0:
                    data_to_plot.append(domain_data.to_numpy())
                    domains.append(domain)

            if data_to_plot:
                # Tick labels are set explicitly: boxplot's ``labels`` keyword is
                # deprecated in Matplotlib 3.9 (renamed ``tick_labels``).
                axes[1].boxplot(data_to_plot)
                axes[1].set_xticks(range(1, len(domains) + 1))
                axes[1].set_xticklabels(domains, rotation=45, ha='right')
                axes[1].set_ylabel('Confidence Score')
                axes[1].set_title('Confidence by Domain')
                axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{vis_dir}/confidence_distribution.png',
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    # ------------------------------------------------------------------
    # Model comparison
    # ------------------------------------------------------------------

    def create_comparison_dashboard(self, bertopic_results, lda_results,
                                    bertopic_df, lda_df, output_dir=None):
        """Create a static and an interactive BERTopic-vs-LDA comparison dashboard.

        Args:
            bertopic_results, lda_results: dicts that must contain
                ``['evaluation']['coherence_score']`` and
                ``['evaluation']['topic_quality']['topic_diversity' |
                'topic_distribution_score']``.
            bertopic_df, lda_df: per-document tables with ``topic``, ``domain``
                and ``confidence`` columns.
            output_dir: base directory. Defaults to ``config.COMPARISON_DIR``,
                looked up at call time.
        """
        if output_dir is None:
            output_dir = config.COMPARISON_DIR
        print("\n Creating comparison dashboard...")

        vis_dir = f"{output_dir}/comparison_visualizations"
        os.makedirs(vis_dir, exist_ok=True)

        # 1. Side-by-side comparison
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        bert_topic_counts = bertopic_df['topic'].value_counts()
        lda_topic_counts = lda_df['topic'].value_counts()

        # BERTopic's outlier topic is not a real topic
        if -1 in bert_topic_counts.index:
            bert_topic_counts = bert_topic_counts.drop(-1)

        topic_totals = [len(bert_topic_counts), len(lda_topic_counts)]
        axes[0, 0].bar(['BERTopic', 'LDA'], topic_totals,
                       color=['blue', 'red'], alpha=0.7)
        axes[0, 0].set_ylabel('Number of Topics')
        axes[0, 0].set_title('Topic Count Comparison')
        axes[0, 0].grid(True, alpha=0.3, axis='y')
        for i, v in enumerate(topic_totals):
            axes[0, 0].text(i, v + 0.1, str(v), ha='center', va='bottom')

        # Domain count comparison
        bert_domain_counts = bertopic_df['domain'].nunique()
        lda_domain_counts = lda_df['domain'].nunique()

        axes[0, 1].bar(['BERTopic', 'LDA'],
                       [bert_domain_counts, lda_domain_counts],
                       color=['blue', 'red'], alpha=0.7)
        axes[0, 1].set_ylabel('Number of Domains')
        axes[0, 1].set_title('Domain Count Comparison')
        axes[0, 1].grid(True, alpha=0.3, axis='y')
        for i, v in enumerate([bert_domain_counts, lda_domain_counts]):
            axes[0, 1].text(i, v + 0.1, str(v), ha='center', va='bottom')

        # Average confidence comparison
        bert_avg_conf = bertopic_df['confidence'].mean()
        lda_avg_conf = lda_df['confidence'].mean()

        axes[0, 2].bar(['BERTopic', 'LDA'],
                       [bert_avg_conf, lda_avg_conf],
                       color=['blue', 'red'], alpha=0.7)
        axes[0, 2].set_ylabel('Average Confidence')
        axes[0, 2].set_title('Confidence Comparison')
        axes[0, 2].grid(True, alpha=0.3, axis='y')
        for i, v in enumerate([bert_avg_conf, lda_avg_conf]):
            axes[0, 2].text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

        # Topic size distribution comparison
        axes[1, 0].hist([bert_topic_counts.values, lda_topic_counts.values], bins=15,
                        label=['BERTopic', 'LDA'],
                        color=['blue', 'red'], alpha=0.7)
        axes[1, 0].set_xlabel('Topic Size')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Topic Size Distribution Comparison')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Domain overlap analysis
        bert_domains = set(bertopic_df['domain'].unique())
        lda_domains = set(lda_df['domain'].unique())

        overlap = bert_domains & lda_domains
        bert_only = bert_domains - lda_domains
        lda_only = lda_domains - bert_domains

        # Venn-diagram-like text boxes
        axes[1, 1].text(0.3, 0.7, f"BERTopic only:\n{len(bert_only)}",
                        ha='center', va='center', fontsize=10,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor='blue', alpha=0.3))
        axes[1, 1].text(0.7, 0.7, f"LDA only:\n{len(lda_only)}",
                        ha='center', va='center', fontsize=10,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor='red', alpha=0.3))
        axes[1, 1].text(0.5, 0.3, f"Overlap:\n{len(overlap)}",
                        ha='center', va='center', fontsize=10,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor='purple', alpha=0.3))
        axes[1, 1].set_xlim(0, 1)
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].set_title('Domain Overlap Analysis')
        axes[1, 1].axis('off')

        # Performance metrics comparison
        metrics = ['Coherence', 'Topic\nDiversity', 'Distribution\nScore']
        bert_eval = bertopic_results['evaluation']
        lda_eval = lda_results['evaluation']
        bert_scores = [
            bert_eval['coherence_score'],
            bert_eval['topic_quality']['topic_diversity'],
            bert_eval['topic_quality']['topic_distribution_score']
        ]
        lda_scores = [
            lda_eval['coherence_score'],
            lda_eval['topic_quality']['topic_diversity'],
            lda_eval['topic_quality']['topic_distribution_score']
        ]

        x = np.arange(len(metrics))
        width = 0.35

        axes[1, 2].bar(x - width / 2, bert_scores, width, label='BERTopic',
                       color='blue', alpha=0.7)
        axes[1, 2].bar(x + width / 2, lda_scores, width, label='LDA',
                       color='red', alpha=0.7)
        axes[1, 2].set_ylabel('Score')
        axes[1, 2].set_title('Quality Metrics Comparison')
        axes[1, 2].set_xticks(x)
        axes[1, 2].set_xticklabels(metrics)
        axes[1, 2].legend()
        axes[1, 2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{vis_dir}/model_comparison_dashboard.png',
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

        # 2. Interactive comparison dashboard
        self._create_interactive_comparison_dashboard(
            bertopic_results, lda_results, bertopic_df, lda_df, vis_dir
        )

        print(f" Comparison dashboard saved to: {vis_dir}")

    def _create_interactive_comparison_dashboard(self, bertopic_results, lda_results,
                                                 bertopic_df, lda_df, vis_dir):
        """Write an interactive grouped-bar metrics table and a radar chart (HTML)."""
        bert_eval = bertopic_results['evaluation']
        lda_eval = lda_results['evaluation']
        bert_is_outlier = bertopic_df['topic'] == -1

        metrics_data = {
            'Metric': ['Topics Found', 'Domains Covered', 'Avg Confidence',
                       'Coherence Score', 'Topic Diversity', 'Distribution Score',
                       'Outliers (%)'],
            'BERTopic': [
                bertopic_df.loc[~bert_is_outlier, 'topic'].nunique(),
                bertopic_df['domain'].nunique(),
                bertopic_df['confidence'].mean(),
                bert_eval['coherence_score'],
                bert_eval['topic_quality']['topic_diversity'],
                bert_eval['topic_quality']['topic_distribution_score'],
                bert_is_outlier.mean() * 100,  # 0.0 when there are no outliers
            ],
            'LDA': [
                lda_df['topic'].nunique(),
                lda_df['domain'].nunique(),
                lda_df['confidence'].mean(),
                lda_eval['coherence_score'],
                lda_eval['topic_quality']['topic_diversity'],
                lda_eval['topic_quality']['topic_distribution_score'],
                0  # LDA assigns every document to a topic
            ]
        }

        df_metrics = pd.DataFrame(metrics_data)

        fig = go.Figure()
        for model_name, color in (('BERTopic', 'blue'), ('LDA', 'red')):
            fig.add_trace(go.Bar(
                x=df_metrics['Metric'],
                y=df_metrics[model_name],
                name=model_name,
                marker_color=color,
                text=df_metrics[model_name].round(3),
                textposition='auto',
            ))

        fig.update_layout(
            title='Model Comparison Dashboard',
            xaxis_tickangle=-45,
            barmode='group',
            yaxis_title='Score/Value',
            height=600,
            showlegend=True
        )

        fig.write_html(f'{vis_dir}/interactive_comparison_dashboard.html')

        # Radar chart comparison (radial axis fixed to [0, 1])
        fig = go.Figure()

        metrics_for_radar = ['Coherence Score', 'Topic Diversity', 'Distribution Score',
                             'Avg Confidence', 'Domain Coverage']

        # NOTE(review): domain coverage is normalised by a hard-coded 8 domains — confirm.
        bert_values = [
            bert_eval['coherence_score'],
            bert_eval['topic_quality']['topic_diversity'],
            bert_eval['topic_quality']['topic_distribution_score'],
            bertopic_df['confidence'].mean(),
            min(bertopic_df['domain'].nunique() / 8, 1.0)
        ]

        lda_values = [
            lda_eval['coherence_score'],
            lda_eval['topic_quality']['topic_diversity'],
            lda_eval['topic_quality']['topic_distribution_score'],
            lda_df['confidence'].mean(),
            min(lda_df['domain'].nunique() / 8, 1.0)
        ]

        for model_name, values, color in (('BERTopic', bert_values, 'blue'),
                                          ('LDA', lda_values, 'red')):
            fig.add_trace(go.Scatterpolar(
                r=values,
                theta=metrics_for_radar,
                fill='toself',
                name=model_name,
                line_color=color
            ))

        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    range=[0, 1]
                )),
            showlegend=True,
            title='Quality Metrics Radar Comparison'
        )

        fig.write_html(f'{vis_dir}/quality_metrics_radar.html')




# Add this method to your EnhancedDomainModelingPipeline class:

def _domain_mapping_from_assignments(assignments_df):
    """Rebuild a ``topic -> domain info`` mapping from a saved assignments table.

    The first row seen for each topic supplies ``primary_domain`` and
    ``confidence``. The assignments CSV does not store keywords, so
    ``topic_keywords`` is always an empty list here.
    """
    first_rows = assignments_df.drop_duplicates(subset='topic', keep='first')
    return {
        topic: {
            'primary_domain': domain,
            'confidence': confidence,
            'topic_keywords': [],  # Would need to load from the topic model
        }
        for topic, domain, confidence in zip(
            first_rows['topic'], first_rows['domain'], first_rows['confidence']
        )
    }


def generate_all_visualizations_in_pipeline(self):
    """Generate BERTopic, LDA and comparison visualizations from saved results.

    Meant to be used as a method of EnhancedDomainModelingPipeline: it reads
    ``self.topic_modeler.topic_model`` and ``self.fixed_lda``. Assignments are
    reloaded from ``config.OUTPUT_DIR`` (BERTopic) and ``config.COMPARISON_DIR``
    (LDA), and evaluations from ``config.OUTPUT_DIR/enhanced_evaluations.json``.
    """
    print("\n STEP 9: GENERATING ENHANCED VISUALIZATIONS")

    # Initialize visualization generator
    vis_generator = EnhancedVisualizationGenerator()

    # Load results
    bert_df = pd.read_csv(f"{config.OUTPUT_DIR}/enhanced_bertopic_assignments.csv")
    lda_df = pd.read_csv(f"{config.COMPARISON_DIR}/enhanced_lda_assignments.csv")

    # Load evaluations
    with open(f"{config.OUTPUT_DIR}/enhanced_evaluations.json", 'r', encoding='utf-8') as f:
        evaluations = json.load(f)

    # Reconstruct domain mappings from the saved per-document assignments
    bert_domain_mapping = _domain_mapping_from_assignments(bert_df)
    lda_domain_mapping = _domain_mapping_from_assignments(lda_df)

    # Generate BERTopic visualizations
    print(" Generating BERTopic visualizations...")
    vis_generator.generate_all_visualizations(
        bert_df,
        self.topic_modeler.topic_model,
        bert_domain_mapping,
        model_type="bertopic",
        output_dir=config.OUTPUT_DIR
    )

    # Generate LDA visualizations
    print(" Generating LDA visualizations...")
    vis_generator.generate_all_visualizations(
        lda_df,
        self.fixed_lda,
        lda_domain_mapping,
        model_type="lda",
        output_dir=config.COMPARISON_DIR
    )

    # Generate comparison dashboard
    print(" Generating comparison dashboard...")
    vis_generator.create_comparison_dashboard(
        bertopic_results={
            'evaluation': evaluations['bertopic'],
            'domain_mapping': bert_domain_mapping
        },
        lda_results={
            'evaluation': evaluations['lda'],
            'domain_mapping': lda_domain_mapping
        },
        bertopic_df=bert_df,
        lda_df=lda_df,
        output_dir=config.COMPARISON_DIR
    )

    print(" Enhanced visualizations generated successfully!")



# Update the __main__ section to include visualizations:

if __name__ == "__main__":
    # Configure logging
    logging.getLogger("bertopic").setLevel(logging.WARNING)
    logging.getLogger("umap").setLevel(logging.WARNING)
    logging.getLogger("hdbscan").setLevel(logging.WARNING)

    print("\n" + "="*70)
    print(" ENHANCED DOMAIN MODELING PIPELINE WITHOUT LLM")
    print("="*70)
    print(f"   • Embedding Models: {list(config.EMBEDDING_MODELS.values())}")
    print(f"   • Enhanced LDA with proper topic extraction")
    print(f"   • Enhanced domain classification with context")
    print(f"   • Comprehensive comparison framework")
    print(f"   • Enhanced visualizations and dashboards")
    print("="*70)

    # Execute enhanced pipeline
    pipeline = EnhancedDomainModelingPipeline()
    results = pipeline.run_enhanced_pipeline()

    # Add visualization step
    if results is not None:
        # Generate visualizations
        try:
            # Create visualization generator
            vis_generator = EnhancedVisualizationGenerator()

            # Generate visualizations for both models
            print("\n" + "="*70)
            print(" GENERATING ENHANCED VISUALIZATIONS")
            print("="*70)

            # Prepare dataframes
            bert_df = pd.DataFrame({
                'topic': results['bertopic']['topics'],
                'domain': [results['bertopic']['domain_mapping'].get(t, {}).get('primary_domain', 'unknown')
                          for t in results['bertopic']['topics']],
                'confidence': [results['bertopic']['domain_mapping'].get(t, {}).get('confidence', 0.0)
                             for t in results['bertopic']['topics']],
                'cleaned_text': results['documents']
            })

            lda_df = pd.DataFrame({
                'topic': results['lda']['topics'],
                'domain': [results['lda']['domain_mapping'].get(t, {}).get('primary_domain', 'unknown')
                          for t in results['lda']['topics']],
                'confidence': [results['lda']['domain_mapping'].get(t, {}).get('confidence', 0.0)
                             for t in results['lda']['topics']],
                'cleaned_text': results['documents']
            })

            # Generate BERTopic visualizations
            print(" Generating BERTopic visualizations...")
            vis_generator.generate_all_visualizations(
                bert_df,
                results['bertopic']['model'],
                results['bertopic']['domain_mapping'],
                model_type="bertopic",
                output_dir=config.OUTPUT_DIR
            )

            # Generate LDA visualizations
            print(" Generating LDA visualizations...")
            vis_generator.generate_all_visualizations(
                lda_df,
                results['lda']['topic_model'],
                results['lda']['domain_mapping'],
                model_type="lda",
                output_dir=config.COMPARISON_DIR
            )

            # Generate comparison dashboard
            print(" Generating comparison dashboard...")
            vis_generator.create_comparison_dashboard(
                results['bertopic'],
                results['lda'],
                bert_df,
                lda_df,
                config.COMPARISON_DIR
            )

            print("\n Enhanced visualizations generated successfully!")

        except Exception as e:
            print(f" Visualization generation failed: {e}")
            import traceback
            traceback.print_exc()

        print(f"\n" + "="*70)
        print(" ENHANCED PIPELINE EXECUTION COMPLETE")
        print("="*70)
        print(f" Check the following directories for results:")
        print(f"   1. {config.OUTPUT_DIR}/ - BERTopic results and visualizations")
        print(f"   2. {config.COMPARISON_DIR}/ - LDA results and enhanced comparison")
        print(f"   3. Review the enhanced comparison report for model selection guidance")
        print(f"\n Visualization directories:")
        print(f"   • {config.OUTPUT_DIR}/visualizations/bertopic/ - BERTopic visualizations")
        print(f"   • {config.COMPARISON_DIR}/visualizations/lda/ - LDA visualizations")
        print(f"   • {config.COMPARISON_DIR}/comparison_visualizations/ - Comparison dashboards")
    else:
        print("\n Enhanced pipeline execution failed")

In [None]:
import os
from google.colab import files

# Colab-only cell: archive the domain-modeling results folder and offer it as a browser download.
# NOTE(review): absolute path assumes the notebook's working directory is /content
# (the pipeline config above uses the relative "domain_modeling_results").
output_dir = "/content/domain_modeling_results"
zip_path = f"{output_dir}.zip"

# Zip the directory (IPython shell escape; "$name" substitutes the Python variable).
# NOTE(review): `zip -r` updates an existing archive in place, so files removed since an
# earlier run may remain inside it — delete zip_path first if a clean archive is needed.
!zip -r "$zip_path" "$output_dir"

# Download the zip file (only if the zip command actually produced it)
if os.path.exists(zip_path):
  files.download(zip_path)
else:
  print(f"Zip file not found at {zip_path}")

In [None]:
!pip install evaluate


In [None]:
# =============================================================================
# PHASE 2: MULTI-DIMENSIONAL DATA CITATION ANALYSIS - UPDATED WITH FILTERING
# =============================================================================

import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForTokenClassification,
    AutoModelForSequenceClassification, TrainingArguments, Trainer
)
import spacy
import re
import json
from collections import defaultdict, Counter
from tqdm import tqdm
import logging
from datasets import Dataset
import evaluate
from typing import List, Dict, Tuple, Any
import warnings
import os
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
warnings.filterwarnings('ignore')

print("Phase 2 imports completed successfully!")

# =============================================================================
# PHASE 2 CONFIGURATION
# =============================================================================

class Phase2Config:
    """Settings for Phase 2 (multi-dimensional data-citation analysis).

    Every setting is a class attribute, so values can be read without an
    instance. Creating an instance additionally ensures ``OUTPUT_DIR`` exists.
    """

    # File paths
    PROCESSED_TEXTS_CSV = "updated.csv"
    TRAIN_LABELS_CSV = "train_labels.csv"
    OUTPUT_DIR = "phase2_data_citation_analysis"

    # Model configurations
    NER_MODEL_NAME = "allenai/scibert_scivocab_uncased"
    CLASSIFICATION_MODEL_NAME = "roberta-base"
    SPACY_MODEL = "en_core_web_sm"

    # Training parameters
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 4
    MAX_LENGTH = 256

    # NER parameters: characters of surrounding text kept on each side of a mention
    CONTEXT_WINDOW = 150

    # Usage type classes
    USAGE_TYPES = ["primary", "secondary", "missing"]
    PURPOSE_TYPES = ["training", "validation", "testing", "evaluation", "benchmarking"]

    # Dataset mention patterns. HybridDatasetNER compiles each one with
    # re.IGNORECASE, so the acronym pattern re-enables case sensitivity locally
    # with (?-i:...). Without it, [A-Z]{2,} would also match ordinary words,
    # yielding mentions like "the dataset" or "this dataset".
    DATASET_PATTERNS = [
        r'\bdataset[s]?\b',
        r'\bdata\s+set[s]?\b',
        r'\bcorpus\b',
        r'\bcollection\b',
        r'\barchive\b',
        r'\brepository\b',
        r'\bbenchmark\b',
        r'\bDB\d+\b',
        r'\b(?-i:[A-Z]{2,})\d*\s+dataset\b',
    ]

    def __init__(self):
        # Side effect: make sure the results directory exists (no error if it already does).
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

# Module-level Phase 2 settings (instantiating also creates Phase2Config.OUTPUT_DIR).
# NOTE(review): this rebinds the global name `config`, which the earlier domain-modeling
# cells also read (e.g. config.COMPARISON_DIR, not defined on Phase2Config) — re-run the
# earlier config cell before calling those functions again.
config = Phase2Config()
print(f" Phase 2 configuration initialized")

# =============================================================================
# ENHANCED DATA LOADER WITH FILE TYPE FILTERING
# =============================================================================

class FilteredDataCitationDataLoader:
    """Load the processed texts and training labels, filter and merge them.

    ``load_and_validate_data`` keeps only PDF/XML documents, normalises
    filenames on both sides into a ``filename_clean`` key, and inner-joins on
    it. When no filenames overlap, usage types are inferred from the text with
    keyword heuristics; when loading fails after the texts were read, every
    text is labelled ``'primary'`` as a fallback.
    """

    # Accepted ``file_type`` values after normalisation (stripped, lower-cased,
    # leading dot removed), so "PDF", "Pdf" and ".pdf" are all kept.
    VALID_FILE_TYPES = ('pdf', 'xml')

    def __init__(self):
        self.processed_texts_df = None
        self.train_labels_df = None
        self.merged_data = None
        self.original_label_distribution = None

    def load_and_validate_data(self):
        """Load, filter, clean and merge the Phase 2 inputs.

        Returns:
            The merged DataFrame (also stored on ``self.merged_data``); it always
            has ``filename_clean``, ``type`` and ``label_source`` columns.

        Raises:
            The original loading error when the processed-texts CSV itself could
            not be read, since there is then nothing to fall back to.
        """
        print(" Loading and validating Phase 2 data...")

        try:
            # Load processed texts
            self.processed_texts_df = pd.read_csv(config.PROCESSED_TEXTS_CSV)
            print(f" Loaded {len(self.processed_texts_df)} documents from {config.PROCESSED_TEXTS_CSV}")

            # Load training labels
            self.train_labels_df = pd.read_csv(config.TRAIN_LABELS_CSV)
            print(f" Loaded {len(self.train_labels_df)} labels from {config.TRAIN_LABELS_CSV}")

            # Store original label distribution for comparison
            self.original_label_distribution = self.train_labels_df['type'].value_counts().to_dict()
            print(f" Original label distribution in train_labels: {self.original_label_distribution}")

            # Step 1: Filter by file_type (keep only PDF and XML)
            self._filter_by_file_type()

            # Step 2: Clean filenames in both datasets
            self._clean_filenames()

            # Step 3: Merge datasets
            self._merge_datasets()

            print(f" Final dataset distribution:")
            final_distribution = self.merged_data['type'].value_counts().to_dict()
            print(f"   - Usage types: {final_distribution}")

            # Compare distributions
            self._compare_label_distributions()

            return self.merged_data

        except Exception as e:
            print(f" Data loading failed: {e}")
            import traceback
            traceback.print_exc()
            if self.processed_texts_df is None:
                # The texts themselves could not be loaded: no fallback is possible.
                raise
            # Fallback: use processed texts only
            print(" Fallback: Using processed texts only")
            self.merged_data = self.processed_texts_df.copy()
            if 'filename_clean' not in self.merged_data.columns:
                # The failure may have happened before _clean_filenames ran;
                # get_texts_for_analysis needs this key.
                self.merged_data['filename_clean'] = self._fallback_filename_keys(self.merged_data)
            self.merged_data['type'] = 'primary'  # Default type
            self.merged_data['label_source'] = 'fallback'
            return self.merged_data

    def _fallback_filename_keys(self, df):
        """Return cleaned filename keys for ``df`` (row index as text when there is no ``filename`` column)."""
        if 'filename' in df.columns:
            return df['filename'].apply(lambda x: self._remove_file_extensions(str(x)))
        return [str(idx) for idx in df.index]

    def _filter_by_file_type(self):
        """Keep only rows of the processed texts whose ``file_type`` is PDF or XML (case-insensitive)."""
        print("\n FILTERING BY FILE TYPE:")

        # Check if file_type column exists
        if 'file_type' not in self.processed_texts_df.columns:
            print(f" 'file_type' column not found in {config.PROCESSED_TEXTS_CSV}")
            print(" Proceeding without file type filtering")
            return

        # Get initial counts
        initial_count = len(self.processed_texts_df)
        file_type_counts = self.processed_texts_df['file_type'].value_counts().to_dict()
        print(f"   • Initial file type distribution: {file_type_counts}")

        # Normalise before comparing so "Pdf", " PDF" and ".pdf" are all accepted.
        normalized_types = (
            self.processed_texts_df['file_type']
            .astype(str).str.strip().str.lower().str.lstrip('.')
        )
        # .copy() so later column assignments act on an independent frame, not a slice.
        self.processed_texts_df = self.processed_texts_df[
            normalized_types.isin(self.VALID_FILE_TYPES)
        ].copy()

        # Get filtered counts
        filtered_count = len(self.processed_texts_df)
        filtered_file_type_counts = self.processed_texts_df['file_type'].value_counts().to_dict()

        print(f"   • After filtering (PDF/XML only): {filtered_file_type_counts}")
        print(f"   • Removed {initial_count - filtered_count} documents, kept {filtered_count} documents")

    def _clean_filenames(self):
        """Add a normalised ``filename_clean`` join key to both input frames."""
        print("\n CLEANING FILENAMES:")

        # Clean processed texts filenames (remove .pdf, .xml extensions)
        self.processed_texts_df['filename_clean'] = self.processed_texts_df['filename'].apply(
            lambda x: self._remove_file_extensions(str(x))
        )

        # Clean train labels filenames (already clean, but standardize)
        self.train_labels_df['filename_clean'] = self.train_labels_df['filename'].apply(
            lambda x: str(x).strip().lower()
        )

        print(f"   • Processed texts sample (cleaned): {self.processed_texts_df['filename_clean'].head(3).tolist()}")
        print(f"   • Train labels sample (cleaned): {self.train_labels_df['filename_clean'].head(3).tolist()}")

    def _remove_file_extensions(self, filename):
        """Lower-case and strip ``filename``, then drop at most one known extension."""
        cleaned = filename.lower().strip()
        extensions = ['.pdf', '.xml', '.txt', '.csv', '.json', '.html']
        for ext in extensions:
            if cleaned.endswith(ext):
                cleaned = cleaned[:-len(ext)]
                break  # Remove only one extension
        return cleaned

    def _merge_datasets(self):
        """Inner-join texts and labels on ``filename_clean``, or infer labels when nothing matches."""
        print("\n MERGING DATASETS:")

        # Find common filenames
        common_filenames = set(self.processed_texts_df['filename_clean']).intersection(
            set(self.train_labels_df['filename_clean'])
        )

        print(f"   • Common filenames found: {len(common_filenames)}")

        if len(common_filenames) > 0:
            # Merge on cleaned filenames
            self.merged_data = pd.merge(
                self.processed_texts_df,
                self.train_labels_df,
                left_on='filename_clean',
                right_on='filename_clean',
                how='inner',
                suffixes=('_text', '_label')
            )
            self.merged_data['label_source'] = 'original'
            print(f" Successfully merged {len(self.merged_data)} documents with original labels")

            # Show sample of matched filenames
            print(f" Sample matched filenames (first 5):")
            for f in list(common_filenames)[:5]:
                print(f"   - {f}")
        else:
            print(" No common filenames found after cleaning. Using enhanced inference.")
            # Use processed texts with enhanced synthetic labels
            self.merged_data = self.processed_texts_df.copy()
            # Add enhanced synthetic usage types
            self.merged_data['type'] = self.merged_data['processed_text'].apply(
                lambda x: self._enhanced_infer_usage_type(x) if pd.notna(x) else 'unknown'
            )
            self.merged_data['label_source'] = 'inferred'
            print(f" Using enhanced inference: {len(self.merged_data)} documents")

    def _enhanced_infer_usage_type(self, text):
        """Infer 'primary' / 'secondary' / 'missing' / 'unknown' from indicator phrases in ``text``."""
        if not isinstance(text, str) or len(text.strip()) < 50:
            return 'unknown'

        text_lower = text.lower()

        # Enhanced heuristic-based inference
        primary_indicators = [
            'our dataset', 'we collected', 'our collection', 'we compiled',
            'collected by us', 'gathered by the authors', 'original dataset',
            'new dataset', 'novel dataset', 'this work introduces', 'we created',
            'developed by us', 'constructed by the authors', 'custom dataset'
        ]

        secondary_indicators = [
            'existing dataset', 'previous work', 'benchmark dataset',
            'standard dataset', 'publicly available', 'downloaded from',
            'obtained from', 'provided by', 'well-known dataset', 'established dataset',
            'widely used', 'popular dataset', 'reference dataset', 'public dataset'
        ]

        missing_indicators = [
            'no dataset', 'without data', 'data not available', 'lack of data',
            'unavailable data', 'restricted access', 'proprietary data',
            'confidential data', 'cannot share', 'not publicly available'
        ]

        primary_score = sum(1 for indicator in primary_indicators if indicator in text_lower)
        secondary_score = sum(1 for indicator in secondary_indicators if indicator in text_lower)
        missing_score = sum(1 for indicator in missing_indicators if indicator in text_lower)

        # Primary wins ties; 'missing' needs at least three indicator phrases.
        if primary_score > 0 and primary_score >= secondary_score:
            return 'primary'
        elif secondary_score > 0 and secondary_score > primary_score:
            return 'secondary'
        elif missing_score > 2:
            return 'missing'
        else:
            return 'unknown'

    def _compare_label_distributions(self):
        """Print the label distribution before and after the merge."""
        if self.original_label_distribution and 'label_source' in self.merged_data.columns:
            original_labels = self.merged_data[self.merged_data['label_source'] == 'original']['type'].value_counts().to_dict()
            inferred_labels = self.merged_data[self.merged_data['label_source'] == 'inferred']['type'].value_counts().to_dict() if 'inferred' in self.merged_data['label_source'].values else {}

            print(f"\n LABEL DISTRIBUTION COMPARISON:")
            print(f"   Original labels in train_labels.csv: {self.original_label_distribution}")
            if original_labels:
                print(f"   After merge - Original labels used: {original_labels}")
            if inferred_labels:
                print(f"   After merge - Inferred labels: {inferred_labels}")

    def get_texts_for_analysis(self):
        """Return ``(texts, filenames, usage_types, label_sources)`` as parallel lists.

        Loads the data first if needed. Texts come from the first available
        column among ``processed_text``, ``final_cleaned_text`` and ``raw_text``
        (empty strings when none exists); filenames are the ``filename_clean`` keys.
        """
        if self.merged_data is None:
            self.load_and_validate_data()

        text_columns = ['processed_text', 'final_cleaned_text', 'raw_text']
        for col in text_columns:
            if col in self.merged_data.columns:
                texts = self.merged_data[col].fillna('').tolist()
                break
        else:
            # If no text columns found, create empty list
            texts = [''] * len(self.merged_data)

        # Use the cleaned filename for consistency
        filenames = self.merged_data['filename_clean'].tolist()
        usage_types = self.merged_data['type'].tolist()
        label_sources = self.merged_data['label_source'].tolist()

        print(f" Texts for analysis: {len(texts)} documents")
        print(f" Label sources: {Counter(label_sources)}")
        lengths = [len(str(t)) for t in texts]
        if lengths:
            print(f" Text length stats: avg={np.mean(lengths):.0f}, "
                  f"min={min(lengths)}, max={max(lengths)}")
        else:
            # np.min / np.max raise on an empty sequence (e.g. everything was filtered out).
            print(" Text length stats: no documents")

        return texts, filenames, usage_types, label_sources


class HybridDatasetNER:
    """Rule-based detector for dataset mentions in free text.

    Combines ``config.DATASET_PATTERNS`` with a few contextual patterns,
    rescores each hit from nearby clues, and attaches the surrounding text
    window (``config.CONTEXT_WINDOW`` characters on each side). A spaCy
    pipeline is loaded when available but is not used by the methods here.
    """

    # Contextual patterns as (compiled regex, capture group holding the mention).
    # They match case-insensitively, so the parts meant to be capitalised names
    # or acronyms switch case sensitivity back on with (?-i:...); otherwise
    # "[A-Z]{2,} dataset" would match "the dataset". In the first pattern the
    # longer "using the" / "using our" alternatives must come before "using",
    # or they never match and the determiner leaks into the captured mention.
    _ENHANCED_PATTERNS = [
        (re.compile(r'\b(using the|using our|using|employed|utilized)\s+([A-Za-z0-9\s\-]+?\s+(?:dataset|corpus|collection))', re.IGNORECASE), 2),
        (re.compile(r'\b(dataset|corpus|collection|benchmark)\s+of\s+((?-i:[A-Z])[A-Za-z0-9\s\-]+)', re.IGNORECASE), 0),
        (re.compile(r'\b((?-i:[A-Z]{2,})(?:\-\d+)?\s+(?:dataset|corpus))', re.IGNORECASE), 0),
        (re.compile(r'\b(ImageNet|CIFAR|MNIST|COCO|SQuAD|GLUE|UCI)\b', re.IGNORECASE), 0),  # Common dataset names
    ]

    def __init__(self):
        self.rule_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in config.DATASET_PATTERNS]
        self.spacy_nlp = None
        self.initialize_models()

    def initialize_models(self):
        """Load the configured spaCy model, or fall back to rule-based only if it is not installed."""
        try:
            self.spacy_nlp = spacy.load(config.SPACY_MODEL)
            print(f" spaCy model loaded: {config.SPACY_MODEL}")
        except OSError:
            print(f" spaCy model not found. Using rule-based only.")
            self.spacy_nlp = None

    def rule_based_dataset_detection(self, text: str) -> List[Dict]:
        """Return raw pattern hits in ``text``.

        Every mention satisfies ``text[m['start']:m['end']] == m['text']``, so
        downstream context extraction sees the same span as the mention text.
        """
        if not isinstance(text, str):
            return []

        mentions = []

        # Basic pattern matching
        for pattern in self.rule_patterns:
            for match in pattern.finditer(text):
                start, end = match.span()
                mentions.append({
                    'text': text[start:end],
                    'start': start,
                    'end': end,
                    'type': 'DATASET',
                    'source': 'rule_based',
                    'confidence': 0.7
                })

        # Enhanced contextual patterns: report the span of the captured group,
        # not of the whole match (which may include e.g. the leading "using").
        for pattern, group_idx in self._ENHANCED_PATTERNS:
            for match in pattern.finditer(text):
                raw = match.group(group_idx)
                if not raw or not raw.strip():
                    continue
                start, end = match.span(group_idx)
                # Shrink the span to the stripped text so span and text agree.
                start += len(raw) - len(raw.lstrip())
                end -= len(raw) - len(raw.rstrip())
                mentions.append({
                    'text': text[start:end],
                    'start': start,
                    'end': end,
                    'type': 'DATASET_MENTION',
                    'source': 'enhanced_rule',
                    'confidence': 0.8
                })

        return mentions

    def detect_dataset_mentions(self, text: str) -> List[Dict]:
        """Detect dataset mentions, rescore them, and keep those with confidence > 0.3 plus their context."""
        if not isinstance(text, str) or len(text.strip()) < 20:
            return []

        try:
            rule_mentions = self.rule_based_dataset_detection(text)

            filtered_mentions = []
            for mention in rule_mentions:
                # Enhance confidence based on context
                enhanced_confidence = self._enhance_confidence(mention, text)
                mention['confidence'] = enhanced_confidence

                # Only keep mentions with reasonable confidence
                if enhanced_confidence > 0.3:
                    mention.update(self.extract_context_around_mention(text, mention))
                    filtered_mentions.append(mention)

            return filtered_mentions

        except Exception as e:
            print(f" Dataset mention detection failed: {e}")
            return []

    def _enhance_confidence(self, mention: Dict, text: str) -> float:
        """Adjust a mention's confidence from known dataset names and nearby technical terms; clamp to [0.1, 1.0]."""
        confidence = mention.get('confidence', 0.5)

        mention_text = mention['text'].lower()
        context = text[max(0, mention['start']-50):min(len(text), mention['end']+50)].lower()

        # Boost confidence for specific patterns
        if any(word in mention_text for word in ['imagenet', 'cifar', 'mnist', 'coco', 'squad', 'glue']):
            confidence += 0.3

        # Boost for technical context
        technical_terms = ['train', 'test', 'validate', 'evaluate', 'accuracy', 'performance']
        if any(term in context for term in technical_terms):
            confidence += 0.2

        # Penalize very short mentions without context
        if len(mention_text) < 5 and 'dataset' not in mention_text:
            confidence -= 0.2

        return min(max(confidence, 0.1), 1.0)

    def extract_context_around_mention(self, text: str, mention: Dict) -> Dict:
        """Return the text window of ``config.CONTEXT_WINDOW`` characters on each side of the mention span."""
        start, end = mention['start'], mention['end']

        context_start = max(0, start - config.CONTEXT_WINDOW)
        context_end = min(len(text), end + config.CONTEXT_WINDOW)

        left_context = text[context_start:start]
        right_context = text[end:context_end]
        mention_text = text[start:end]

        return {
            'left_context': left_context.strip(),
            'mention_text': mention_text,
            'right_context': right_context.strip(),
            'full_context': f"{left_context} {mention_text} {right_context}".strip(),
            'context_start': context_start,
            'context_end': context_end
        }



class PurposeClassifier:
    """Keyword-based guess at why a dataset is used, from the text around a mention."""

    def __init__(self):
        self.purpose_keywords = self._initialize_purpose_keywords()

    def _initialize_purpose_keywords(self) -> Dict[str, List[str]]:
        """Return the purpose -> indicative keywords table used for scoring."""
        return {
            'training': [
                'train', 'training', 'learn', 'learning', 'fit', 'fitting', 'optimize',
                'optimization', 'parameter', 'weight', 'model training', 'train set',
                'training data', 'learning algorithm', 'backpropagation', 'gradient descent'
            ],
            'validation': [
                'validate', 'validation', 'tune', 'tuning', 'hyperparameter', 'development set',
                'dev set', 'validation set', 'model selection', 'parameter tuning',
                'cross validation', 'early stopping', 'hyperparameter optimization'
            ],
            'testing': [
                'test', 'testing', 'evaluate', 'evaluation', 'assess', 'assessment',
                'test set', 'testing data', 'performance', 'accuracy', 'result',
                'benchmark', 'comparison', 'final evaluation', 'generalization'
            ],
            'evaluation': [
                'evaluate', 'evaluation', 'assess', 'assessment', 'measure', 'metric',
                'performance', 'accuracy', 'precision', 'recall', 'f1', 'auc', 'roc',
                'metrics', 'quantitative analysis', 'qualitative analysis'
            ],
            'benchmarking': [
                'benchmark', 'benchmarking', 'compare', 'comparison', 'baseline',
                'state of the art', 'sota', 'competitive', 'performance comparison',
                'leaderboard', 'standard benchmark', 'reference implementation'
            ]
        }

    @staticmethod
    def _nearest_distance(haystack: str, needle: str, anchor: int) -> int:
        """Return the smallest ``|pos - anchor|`` over all occurrences of ``needle``, or -1 if absent."""
        best = -1
        pos = haystack.find(needle)
        while pos != -1:
            dist = abs(pos - anchor)
            if best < 0 or dist < best:
                best = dist
            pos = haystack.find(needle, pos + 1)
        return best

    def classify_purpose(self, context: Dict) -> Dict:
        """Score each purpose from keywords in ``context['full_context']``.

        Each keyword present adds 1.0, boosted x1.5 when its nearest occurrence
        is within 50 characters of the mention (x1.2 within 100); the mention is
        assumed to start about ``len(context['left_context'])`` characters in.
        Scores are normalised to sum to 1.

        Returns:
            ``{'primary_purpose', 'confidence', 'purpose_scores'}``;
            ``primary_purpose`` is ``'unknown'`` with confidence 0.0 when no
            keyword occurs at all.
        """
        full_context = context['full_context'].lower()
        mention_pos = len(context['left_context'])

        purpose_scores = {purpose: 0.0 for purpose in self.purpose_keywords}

        for purpose, keywords in self.purpose_keywords.items():
            for keyword in keywords:
                distance = self._nearest_distance(full_context, keyword, mention_pos)
                if distance < 0:
                    continue
                weight = 1.0
                # Higher weight for keywords closer to the mention
                if distance < 50:
                    weight *= 1.5
                elif distance < 100:
                    weight *= 1.2
                purpose_scores[purpose] += weight

        total_score = sum(purpose_scores.values())
        if total_score <= 0:
            # No evidence: do not report the first dict key ('training') as the purpose.
            return {
                'primary_purpose': 'unknown',
                'confidence': 0.0,
                'purpose_scores': purpose_scores
            }

        for purpose in purpose_scores:
            purpose_scores[purpose] /= total_score

        primary_purpose, confidence = max(purpose_scores.items(), key=lambda item: item[1])

        return {
            'primary_purpose': primary_purpose,
            'confidence': confidence,
            'purpose_scores': purpose_scores
        }



class DatasetImpactAnalyzer:
    """Aggregate per-dataset usage statistics and derive impact/reproducibility metrics.

    Each call to ``analyze_dataset_impact`` starts from empty aggregates, so the
    same analyzer instance can be reused without double-counting.
    """

    def __init__(self, usage_types=None, purpose_types=None):
        """Create an analyzer.

        Args:
            usage_types: optional list of all usage-type categories used as the
                denominator of ``usage_diversity``; when None,
                ``config.USAGE_TYPES`` is read at metric-computation time.
            purpose_types: optional list of all purpose categories used as the
                denominator of ``purpose_diversity``; when None,
                ``config.PURPOSE_TYPES`` is read at metric-computation time.
        """
        self.usage_types = usage_types
        self.purpose_types = purpose_types
        self._reset_stats()

    def _reset_stats(self):
        """Replace the aggregate containers with fresh, empty ones."""
        self.dataset_usage_stats = defaultdict(lambda: {
            'mention_count': 0,
            'papers_count': 0,
            'usage_types': Counter(),
            'purposes': Counter(),
            'contexts': []
        })
        # label_source -> Counter(usage_type -> mention count)
        self.label_source_stats = defaultdict(Counter)

    def analyze_dataset_impact(self, analysis_results: List[Dict]) -> Dict:
        """Analyze impact metrics with robust error handling.

        Args:
            analysis_results: per-document dicts; documents without a
                ``'dataset_mentions'`` key are skipped, and mentions whose
                ``'text'`` is missing are ignored.

        Returns:
            Dict with per-dataset metrics, the mean reproducibility score, totals,
            and label-source breakdowns. Returns the empty analysis when the input
            is empty or any error occurs (the error is printed).
        """
        print(" Analyzing dataset impact metrics...")
        # Only the results passed to THIS call are analysed.
        self._reset_stats()

        if not analysis_results:
            return self._get_empty_impact_analysis()

        try:
            for result in analysis_results:
                if 'dataset_mentions' not in result:
                    continue

                filename = result.get('filename', 'unknown')
                label_source = result.get('label_source', 'unknown')

                for mention in result['dataset_mentions']:
                    dataset_name = mention.get('text', 'unknown_dataset')
                    if dataset_name == 'unknown_dataset':
                        continue

                    usage_type = mention.get('usage_type', 'unknown')
                    # `or {}` also tolerates an explicit purpose=None.
                    purpose = (mention.get('purpose') or {}).get('primary_purpose', 'unknown')

                    stats = self.dataset_usage_stats[dataset_name]
                    stats['mention_count'] += 1
                    stats['usage_types'][usage_type] += 1
                    stats['purposes'][purpose] += 1
                    stats['contexts'].append({
                        'filename': filename,
                        'usage_type': usage_type,
                        'purpose': purpose,
                        'label_source': label_source
                    })

                    # Track label source distribution
                    self.label_source_stats[label_source][usage_type] += 1

            # A dataset mentioned several times in one paper counts that paper once.
            for stats in self.dataset_usage_stats.values():
                stats['papers_count'] = len({ctx['filename'] for ctx in stats['contexts']})

            impact_metrics = {
                dataset_name: self._calculate_dataset_metrics(dataset_name, stats)
                for dataset_name, stats in self.dataset_usage_stats.items()
            }

            return {
                'dataset_impact_metrics': impact_metrics,
                'reproducibility_score': self._calculate_reproducibility_score(impact_metrics),
                'total_datasets_analyzed': len(impact_metrics),
                'total_mentions': sum(stats['mention_count'] for stats in self.dataset_usage_stats.values()),
                'label_source_distribution': dict(self.label_source_stats),
                'usage_type_by_source': self._analyze_usage_by_source()
            }

        except Exception as e:
            print(f" Impact analysis failed: {e}")
            return self._get_empty_impact_analysis()

    def _analyze_usage_by_source(self) -> Dict:
        """Return ``{label_source: {usage_type: count}}`` as plain dicts."""
        return {source: dict(counter) for source, counter in self.label_source_stats.items()}

    def _get_empty_impact_analysis(self):
        """Return empty impact analysis when no data is available"""
        return {
            'dataset_impact_metrics': {},
            'reproducibility_score': 0.0,
            'total_datasets_analyzed': 0,
            'total_mentions': 0,
            'label_source_distribution': {},
            'usage_type_by_source': {}
        }

    def _calculate_dataset_metrics(self, dataset_name: str, stats: Dict) -> Dict:
        """Calculate frequency, diversity, impact and reproducibility metrics for one dataset.

        impact_score = 0.6 * min(1, mentions / 10) + 0.4 * mean(usage_diversity, purpose_diversity)
        """
        total_mentions = stats['mention_count']
        total_papers = stats['papers_count']

        usage_types = self.usage_types if self.usage_types is not None else config.USAGE_TYPES
        purpose_types = self.purpose_types if self.purpose_types is not None else config.PURPOSE_TYPES

        # max(..., 1) guards against an empty category list (ZeroDivisionError).
        usage_diversity = len(stats['usage_types']) / max(len(usage_types), 1)
        purpose_diversity = len(stats['purposes']) / max(len(purpose_types), 1)

        # Frequency saturates at 10 mentions.
        frequency_score = min(1.0, total_mentions / 10)
        diversity_score = (usage_diversity + purpose_diversity) / 2
        impact_score = 0.6 * frequency_score + 0.4 * diversity_score

        reproducibility_indicators = {
            'multiple_usage_contexts': len(stats['usage_types']) > 1,
            'multiple_papers': total_papers > 1,
            # NOTE(review): every counted mention records a purpose (at least
            # 'unknown'), so this indicator is always True as written.
            'consistent_purposes': len(stats['purposes']) >= 1
        }
        reproducibility_score = sum(reproducibility_indicators.values()) / len(reproducibility_indicators)

        return {
            'dataset_name': dataset_name,
            'mention_count': total_mentions,
            'papers_count': total_papers,
            'usage_type_distribution': dict(stats['usage_types']),
            'purpose_distribution': dict(stats['purposes']),
            'usage_diversity': usage_diversity,
            'purpose_diversity': purpose_diversity,
            'impact_score': impact_score,
            'reproducibility_score': reproducibility_score,
            'reproducibility_indicators': reproducibility_indicators
        }

    def _calculate_reproducibility_score(self, impact_metrics: Dict) -> float:
        """Return the mean per-dataset reproducibility score (0.0 when there are no datasets)."""
        if not impact_metrics:
            return 0.0
        total_reproducibility = sum(metric['reproducibility_score'] for metric in impact_metrics.values())
        return total_reproducibility / len(impact_metrics)



class FixedResultsComparator:
    """Compare detected per-document usage types with ground-truth labels.

    Filenames on both sides are normalised with ``str(name).strip().lower()``
    before matching. When a filename appears in several label rows, the first
    row's ``'type'`` is used.
    """

    # Classes reported in the confusion matrix and in the text report.
    _LABELS = ('Primary', 'Secondary', 'Missing')

    def __init__(self):
        # Results of the most recent successful comparison ({} when none).
        self.comparison_results = {}

    @staticmethod
    def _clean_filename(name) -> str:
        """Normalise a filename for matching (strip surrounding whitespace, lower-case)."""
        return str(name).strip().lower()

    def _collect_label_pairs(self, analysis_results: List[Dict], original_labels_df: pd.DataFrame):
        """Pair each result's detected usage type with its ground-truth label.

        Returns:
            ``(detected_types, original_types, matched_filenames)`` as parallel
            lists, containing only results whose cleaned filename has a label.
        """
        # One dict lookup per result instead of scanning the DataFrame each time;
        # setdefault keeps the FIRST label row per cleaned filename.
        label_lookup = {}
        for raw_name, label in zip(original_labels_df['filename'], original_labels_df['type']):
            label_lookup.setdefault(self._clean_filename(raw_name), label)

        detected_types, original_types, matched_files = [], [], []
        for result in analysis_results:
            # Clean the detected filename too, so case/whitespace differences still match.
            filename = self._clean_filename(result.get('filename', ''))
            if filename in label_lookup:
                detected_types.append(result.get('usage_type', 'unknown'))
                original_types.append(label_lookup[filename])
                matched_files.append(filename)
        return detected_types, original_types, matched_files

    def compare_with_original_labels(self, analysis_results: List[Dict], original_labels_df: pd.DataFrame):
        """Compare detected usage types with original labels using cleaned filenames.

        Args:
            analysis_results: per-document dicts with ``'filename'`` and ``'usage_type'``.
            original_labels_df: DataFrame with ``'filename'`` and ``'type'`` columns.

        Returns:
            Dict with accuracy, classification report, confusion matrix, labels and
            sample sizes; None when there are no labels or no filename matched.
        """
        print("\n COMPARING RESULTS WITH ORIGINAL LABELS")
        # Clear any previous run so a failed comparison cannot leave stale results
        # behind for generate_comparison_report.
        self.comparison_results = {}

        if original_labels_df is None or original_labels_df.empty:
            print(" No original labels available for comparison")
            return None

        detected_usage_types, original_usage_types, matched_files = self._collect_label_pairs(
            analysis_results, original_labels_df
        )

        if not detected_usage_types:
            print(" No matching files found for comparison")
            print(" This might indicate a filename cleaning issue")
            return None

        labels = list(self._LABELS)
        accuracy = accuracy_score(original_usage_types, detected_usage_types)
        class_report = classification_report(original_usage_types, detected_usage_types, output_dict=True)
        cm = confusion_matrix(original_usage_types, detected_usage_types, labels=labels)

        self.comparison_results = {
            'accuracy': accuracy,
            'classification_report': class_report,
            'confusion_matrix': cm.tolist(),
            'labels': labels,
            'sample_size': len(detected_usage_types),
            'matched_files_count': len(matched_files)
        }

        print(f" Comparison completed:")
        print(f"   • Sample size: {len(detected_usage_types)} files")
        print(f"   • Accuracy: {accuracy:.3f}")
        print(f"   • Matched files: {len(matched_files)}")

        return self.comparison_results

    def generate_comparison_report(self, output_dir: str):
        """Write ``label_comparison_report.txt`` into *output_dir* (created if missing).

        Does nothing but print a notice when no comparison results are stored.
        """
        if not self.comparison_results:
            print(" No comparison results available")
            return

        os.makedirs(output_dir, exist_ok=True)
        report_path = os.path.join(output_dir, "label_comparison_report.txt")
        results = self.comparison_results

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("LABEL COMPARISON REPORT\n")
            f.write("=" * 50 + "\n\n")

            f.write(f"Sample Size: {results['sample_size']} files\n")
            f.write(f"Overall Accuracy: {results['accuracy']:.3f}\n")
            f.write(f"Matched Files: {results['matched_files_count']}\n\n")

            f.write("Classification Report:\n")
            for class_name, metrics in results['classification_report'].items():
                # Skip aggregate rows (accuracy, macro/weighted avg) and unexpected classes.
                if class_name in self._LABELS:
                    f.write(f"  {class_name}:\n")
                    f.write(f"    Precision: {metrics['precision']:.3f}\n")
                    f.write(f"    Recall: {metrics['recall']:.3f}\n")
                    f.write(f"    F1-Score: {metrics['f1-score']:.3f}\n")
                    f.write(f"    Support: {metrics['support']}\n\n")

        print(f" Comparison report saved to: {report_path}")



class FilteredDataCitationAnalysisPipeline:
    """Phase 2 pipeline: detect dataset mentions, classify their purpose, analyse
    dataset impact, compare with original labels and save all outputs."""

    def __init__(self):
        self.data_loader = FilteredDataCitationDataLoader()
        self.ner_system = HybridDatasetNER()
        self.purpose_classifier = PurposeClassifier()
        self.impact_analyzer = DatasetImpactAnalyzer()
        self.comparator = FixedResultsComparator()

    def run_complete_analysis(self):
        """Run the complete Phase 2 analysis pipeline with file type filtering.

        Returns:
            ``(purpose_analysis_results, impact_analysis)`` on success, or
            ``(None, None)`` when no texts are available or a core step fails.
            A failure of the optional label comparison (step 5) is reported but
            does not discard the results of the other steps.
        """
        print(" STARTING PHASE 2: MULTI-DIMENSIONAL DATA CITATION ANALYSIS")
        print("=" * 80)

        try:
            # Step 1: Load and prepare data with filtering
            print("\n STEP 1: DATA LOADING AND PREPARATION")
            # Called for its loading/validation side effects; texts are fetched next.
            self.data_loader.load_and_validate_data()
            texts, filenames, usage_types, label_sources = self.data_loader.get_texts_for_analysis()

            if len(texts) == 0:
                print(" No texts available for analysis")
                return None, None

            # Step 2: Dataset mention detection
            print("\n STEP 2: HYBRID DATASET MENTION DETECTION")
            dataset_mentions_results = self._detect_dataset_mentions(texts, filenames, usage_types, label_sources)

            # Step 3: Purpose classification
            print("\n STEP 3: PURPOSE CLASSIFICATION AND METHODOLOGY EXTRACTION")
            purpose_analysis_results = self._analyze_purposes(dataset_mentions_results)

            # Step 4: Impact analysis
            print("\n STEP 4: DATASET IMPACT ANALYSIS")
            impact_analysis = self.impact_analyzer.analyze_dataset_impact(purpose_analysis_results)

            # Step 5: Compare with original labels (best effort)
            print("\n STEP 5: COMPARISON WITH ORIGINAL LABELS")
            comparison_results = self._compare_with_labels(purpose_analysis_results)

            # Step 6: Save results
            print("\n STEP 6: SAVING COMPREHENSIVE RESULTS")
            self._save_comprehensive_results(purpose_analysis_results, impact_analysis, comparison_results)

            print("\n PHASE 2 COMPLETED SUCCESSFULLY!")
            return purpose_analysis_results, impact_analysis

        except Exception as e:
            print(f" Phase 2 pipeline failed: {e}")
            import traceback
            traceback.print_exc()
            return None, None

    def _compare_with_labels(self, analysis_results: List[Dict]):
        """Compare against the loader's ``train_labels_df``; return None on any failure."""
        labels_df = getattr(self.data_loader, 'train_labels_df', None)
        try:
            return self.comparator.compare_with_original_labels(analysis_results, labels_df)
        except Exception as e:
            # Comparison is an evaluation extra; keep the analysis results if it fails.
            print(f" Label comparison failed: {e}")
            return None

    def _detect_dataset_mentions(self, texts: List[str], filenames: List[str], usage_types: List[str], label_sources: List[str]) -> List[Dict]:
        """Run the NER system on every text and tag each mention with its document's usage type.

        Texts that are not strings (e.g. NaN from pandas) or are shorter than 50
        characters after stripping are skipped. A failure on one document is
        printed and that document is skipped.
        """
        results = []
        total_mentions = 0

        for text, filename, usage_type, label_source in tqdm(
            zip(texts, filenames, usage_types, label_sources),
            total=len(texts),
            desc="Detecting dataset mentions"
        ):
            if not isinstance(text, str) or len(text.strip()) < 50:
                continue

            try:
                mentions = self.ner_system.detect_dataset_mentions(text)

                # Every mention inherits the document-level usage label.
                for mention in mentions:
                    mention['usage_type'] = usage_type

                results.append({
                    'filename': filename,
                    'text_length': len(text),
                    'dataset_mentions': mentions,
                    'mentions_count': len(mentions),
                    'usage_type': usage_type,
                    'label_source': label_source
                })
                total_mentions += len(mentions)

            except Exception as e:
                print(f" Error processing document {filename}: {e}")
                continue

        print(f" Dataset mention detection completed: {total_mentions} mentions found in {len(results)} documents")
        return results

    def _analyze_purposes(self, mention_results: List[Dict]) -> List[Dict]:
        """Attach a ``'purpose'`` classification to every mention.

        Mention dicts are updated in place. Mentions whose classification raises
        are dropped from their document.
        """
        analyzed_results = []

        for result in tqdm(mention_results, desc="Analyzing purposes"):
            analyzed_mentions = []

            for mention in result['dataset_mentions']:
                try:
                    mention['purpose'] = self.purpose_classifier.classify_purpose(mention)
                    analyzed_mentions.append(mention)
                except Exception as e:
                    print(f" Purpose analysis failed for mention: {e}")
                    continue

            result['dataset_mentions'] = analyzed_mentions
            analyzed_results.append(result)

        print(f" Purpose analysis completed for {len(analyzed_results)} documents")
        return analyzed_results

    def _save_comprehensive_results(self, analysis_results: List[Dict], impact_analysis: Dict, comparison_results: Dict, output_dir=None):
        """Save all Phase 2 results; errors are printed, not raised.

        Writes into *output_dir* (default ``config.OUTPUT_DIR``, created if missing):
        ``detailed_citation_analysis.csv`` (only when there are mentions),
        ``dataset_impact_analysis.json`` and ``impact_analysis_report.txt``, plus
        ``label_comparison_results.json`` and the comparator's text report when
        comparison results are given.
        """
        if output_dir is None:
            output_dir = config.OUTPUT_DIR

        try:
            os.makedirs(output_dir, exist_ok=True)

            # One row per mention; contexts truncated to 200 chars to keep the CSV small.
            detailed_results = []
            for result in analysis_results:
                for mention in result.get('dataset_mentions', []):
                    detailed_results.append({
                        'filename': result.get('filename', 'unknown'),
                        'dataset_name': mention.get('text', 'unknown'),
                        'mention_confidence': mention.get('confidence', 0),
                        'usage_type': mention.get('usage_type', 'unknown'),
                        'primary_purpose': mention.get('purpose', {}).get('primary_purpose', 'unknown'),
                        'purpose_confidence': mention.get('purpose', {}).get('confidence', 0),
                        'left_context': mention.get('left_context', '')[:200],
                        'right_context': mention.get('right_context', '')[:200],
                        'detection_source': mention.get('source', 'unknown'),
                        'label_source': result.get('label_source', 'unknown')
                    })

            if detailed_results:
                detailed_file = os.path.join(output_dir, "detailed_citation_analysis.csv")
                pd.DataFrame(detailed_results).to_csv(detailed_file, index=False)
                print(f" Detailed analysis saved to: {detailed_file}")
            else:
                print(" No detailed results to save")

            # Impact and comparison outputs do not depend on mentions existing.
            if impact_analysis:
                impact_file = os.path.join(output_dir, "dataset_impact_analysis.json")
                with open(impact_file, 'w', encoding='utf-8') as f:
                    json.dump(impact_analysis, f, indent=2, default=str)
                print(f" Impact analysis saved to: {impact_file}")

                report_file = os.path.join(output_dir, "impact_analysis_report.txt")
                report = self._generate_impact_report(impact_analysis, len(detailed_results))
                with open(report_file, 'w', encoding='utf-8') as f:
                    f.write(report)
                print(f" Impact report saved to: {report_file}")

            if comparison_results:
                comparison_file = os.path.join(output_dir, "label_comparison_results.json")
                with open(comparison_file, 'w', encoding='utf-8') as f:
                    # default=str guards against non-JSON types (e.g. numpy scalars).
                    json.dump(comparison_results, f, indent=2, default=str)
                print(f" Comparison results saved to: {comparison_file}")

                self.comparator.generate_comparison_report(output_dir)

        except Exception as e:
            print(f" Error saving results: {e}")

    def _generate_impact_report(self, impact_analysis: Dict, total_mentions: int) -> str:
        """Render *impact_analysis* as a plain-text report.

        *total_mentions* is currently unused; the report takes its mention count
        from ``impact_analysis['total_mentions']``. The top five datasets by
        ``impact_score`` are listed.
        """
        report = []
        report.append("=" * 80)
        report.append("DATASET IMPACT ANALYSIS REPORT")
        report.append("=" * 80)

        report.append(f"\n OVERVIEW:")
        report.append(f"   • Total datasets analyzed: {impact_analysis['total_datasets_analyzed']}")
        report.append(f"   • Total mentions: {impact_analysis['total_mentions']}")
        report.append(f"   • Overall reproducibility score: {impact_analysis['reproducibility_score']:.3f}")

        if 'label_source_distribution' in impact_analysis:
            report.append(f"\n LABEL SOURCE DISTRIBUTION:")
            for source, count in impact_analysis['label_source_distribution'].items():
                report.append(f"   • {source}: {sum(count.values())} mentions")

        if 'usage_type_by_source' in impact_analysis:
            report.append(f"\n USAGE TYPE BY SOURCE:")
            for source, usage_dist in impact_analysis['usage_type_by_source'].items():
                report.append(f"   • {source}: {usage_dist}")

        if impact_analysis['dataset_impact_metrics']:
            report.append(f"\n TOP DATASETS BY IMPACT SCORE:")
            top_datasets = sorted(
                impact_analysis['dataset_impact_metrics'].items(),
                key=lambda x: x[1]['impact_score'],
                reverse=True
            )[:5]

            for dataset_name, metrics in top_datasets:
                report.append(f"   • {dataset_name}:")
                report.append(f"     - Impact Score: {metrics['impact_score']:.3f}")
                report.append(f"     - Mentions: {metrics['mention_count']}")
                report.append(f"     - Papers: {metrics['papers_count']}")
                report.append(f"     - Reproducibility: {metrics['reproducibility_score']:.3f}")
                report.append(f"     - Usage Types: {metrics['usage_type_distribution']}")
        else:
            report.append(f"\n No dataset impact metrics available")

        return "\n".join(report)


if __name__ == "__main__":
    print(" PHASE 2: MULTI-DIMENSIONAL DATA CITATION ANALYSIS - FILTERED VERSION")
    print("=" * 80)

    # Build the pipeline and run every Phase 2 step end to end.
    phase2_pipeline = FilteredDataCitationAnalysisPipeline()
    analysis_results, impact_analysis = phase2_pipeline.run_complete_analysis()

    if analysis_results is None:
        print("\n Phase 2 execution failed")
    else:
        # Summary statistics over the per-document results.
        total_mentions = sum(len(doc['dataset_mentions']) for doc in analysis_results)
        unique_datasets = {
            found['text']
            for doc in analysis_results
            for found in doc['dataset_mentions']
        }
        # Documents (not mentions) counted per label source.
        label_sources = Counter(doc.get('label_source', 'unknown') for doc in analysis_results)

        print(f"\n PHASE 2 FINAL SUMMARY:")
        print("=" * 50)
        print(f" Detection Results:")
        print(f"   • Documents analyzed: {len(analysis_results)}")
        print(f"   • Dataset mentions found: {total_mentions}")
        print(f"   • Unique datasets: {len(unique_datasets)}")
        print(f"   • Label sources: {dict(label_sources)}")

        if impact_analysis:
            overall = impact_analysis['reproducibility_score']
            dataset_total = impact_analysis['total_datasets_analyzed']
            print(f" Impact Analysis:")
            print(f"   • Overall reproducibility: {overall:.3f}")
            print(f"   • Total datasets with impact metrics: {dataset_total}")

        print(f"\n Results saved to: {config.OUTPUT_DIR}")
        print(" Phase 2 completed successfully!")

In [None]:
import os
from google.colab import files

# Colab-only cell: zip the Phase 2 output folder and download it to the browser.
# NOTE(review): this path is hard-coded; confirm it matches config.OUTPUT_DIR used by
# the pipeline, otherwise a different (possibly missing) folder is archived.
output_dir = "/content/phase2_data_citation_analysis"
zip_path = f"{output_dir}.zip"

# Zip the directory via an IPython shell escape (`!`), so this only runs in Jupyter/Colab.
# NOTE(review): `zip -r` updates an existing archive instead of replacing it, so files
# from earlier runs may linger in the zip — confirm whether that matters.
!zip -r "$zip_path" "$output_dir"

# Download the zip file if the zip step produced one; otherwise report the missing path.
if os.path.exists(zip_path):
  files.download(zip_path)
else:
  print(f"Zip file not found at {zip_path}")