<a href="https://colab.research.google.com/github/Chaaa76/nlp-model/blob/main/Finalnlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %% [code]
# Install required libraries/dependencies
!pip install -q transformers sentence-transformers pandas scikit-learn numpy optuna rank_bm25 datasets
!pip install -q --upgrade transformers
!pip install -q rank_bm25
!pip install -q --upgrade datasets

In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"

import pandas as pd
import numpy as np
import torch
import json
import logging
import optuna
from datetime import datetime
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import ndcg_score
from rank_bm25 import BM25Okapi
from typing import List, Dict, Tuple, Optional, Union
from sentence_transformers import SentenceTransformer, CrossEncoder, losses, InputExample
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import time
import re

# Enhanced logging configuration
# Root-logger setup: INFO level, timestamped "time - level - message" lines,
# emitted both to the console stream and to 'ordinance_retrieval.log' in the
# current working directory (the file handler is created when this cell runs).
# NOTE(review): logging.basicConfig is documented as a no-op when the root
# logger already has handlers (common in notebook environments) -- confirm the
# file handler is actually attached, or pass force=True if it is not.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('ordinance_retrieval.log')
    ]
)
# Named logger used by every method of OrdinanceRetrievalSystem in this file.
logger = logging.getLogger("OrdinanceRetrieval")

class OrdinanceRetrievalSystem:
    def __init__(self):
        """Initialize the retrieval system with default parameters"""
        self.df = None
        self.model = None
        self.cross_encoder = None
        self.embeddings = None
        self.bm25 = None
        self.tokenizer = None
        self.id_to_idx = {}
        self.available_ordinance_ids = set()
        self.best_params = {
            'semantic_weight': 0.6,
            'ce_weight': 0.5,
            'batch_size': 8,
            'learning_rate': 3e-5,
            'epochs': 2
        }
        self.query_expansion_terms = {
        "tax": ["levy", "revenue", "assessment", "dues", "collection", "fiscal", "fee", "amusement tax", "real property tax"],
        "property": ["real estate", "land", "building", "house", "lot", "title", "structure", "premises", "estate", "realty"],
        "business": ["enterprise", "commerce", "establishment", "shop", "store", "business permit", "commercial", "firm", "company", "trade"],
        "health": ["sanitation", "hygiene", "medical", "clinic", "cleanliness", "public health", "healthcare", "infection control", "health program"],
        "smoking": ["tobacco", "cigarette", "vape", "e-cigarette", "nicotine", "no smoking", "smoke-free", "secondhand smoke", "smoking ban"],
        "alcohol": ["liquor", "alcoholic drink", "alcoholic beverage", "booze", "beer", "wine", "hard drinks", "intoxicating drink", "drinking ban", "liquor ban", "alcoholic","drinking alcohol","sell liquor to minors"],
        "soft drinks": ["carbonated drink", "soda", "beverage", "bottled drink", "sweetened drink", "soft beverage", "colas", "refreshment"],
        "license": ["permit", "authorization", "certification", "registration", "approval", "license to operate", "business license", "franchise"],
        "waste": ["waste disposal", "garbage", "trash", "refuse", "rubbish", "disposal", "solid waste", "junk", "basura", "collection schedule", "segregation"],
        "littering": ["waste disposal", "garbage", "trash", "illegal dumping", "rubbish", "improper disposal", "throwing waste", "scattering garbage", "basura"],
        "cleanliness": ["sanitation", "cleaning", "public space", "tidy", "sweeping", "maintenance", "clean-up", "street cleaning", "city hygiene"],
        "curfew": ["time restriction", "minor restriction", "kids outside", "night ban", "youth curfew", "curfew hours", "discipline hours", "prohibited time"],
        "noise": ["loud sounds", "speakers", "karaoke", "disturbance", "amplifier", "sound pollution", "unnecessary noise", "noise control"],
        "vehicles": ["cars", "motorcycles", "jeepney", "tricycle", "parking", "traffic", "automobile", "vehicle regulation", "transport"],
        "parking": ["no parking", "parked car", "tow", "illegal parking", "street parking", "reserved parking", "paid parking", "parking violation"],
        "market": ["public market", "vendor", "stall", "palengke", "selling area", "wet market", "dry goods", "marketplace", "tiangge"],
        "environment": ["pollution", "air quality", "green", "climate", "clean air", "ecology", "waste management", "environmental protection", "carbon"],
        "vending": ["street vendor", "hawker", "sidewalk selling", "illegal seller", "ambulant vendor", "vending stall", "peddling", "vendor regulation"],
        "construction": ["building", "renovation", "development", "permit to build", "excavation", "infrastructure", "structural work", "building code"],
        "zoning": ["land use", "residential area", "commercial zone", "reclassification", "urban planning", "zone ordinance", "rezone", "land designation"],
        "penalty": ["fine", "punishment", "fee", "ticket", "sanction", "imprisonment", "penalized", "violation consequence", "offense fee"],
        "school": ["education", "student", "learning", "public school", "academic", "elementary", "high school", "teacher", "educational institution"],
        "barangay": ["local government", "neighborhood", "community office", "barangay hall", "barangay captain", "local unit", "barangay council"],
        "senior citizen": ["elderly", "senior", "discount", "benefit", "ID", "senior card", "pension", "social protection", "elder care","allowance","monthly","old people","old"],
        "PWD": ["disabled", "person with disability", "handicapped", "benefits", "ID", "PWD card", "accessibility", "inclusive", "disability support"],
        "animal": ["dog", "cat", "pet", "stray", "animal control", "bite", "rabies", "pet registration", "veterinary", "animal welfare"]
        }


    def load_data(self, file_path: str) -> pd.DataFrame:
        """Load the ordinance CSV, clean it, and store it on the instance.

        Steps: read the CSV with every known column as ``str``, require the
        ``ordinance_id`` and ``short_text`` columns, collapse whitespace in
        the text columns, fill in placeholders for missing category / status /
        date, drop rows whose ``short_text`` is 30 characters or shorter, then
        set ``self.df`` and ``self.available_ordinance_ids``.

        Args:
            file_path: Path to the CSV file.

        Returns:
            The cleaned DataFrame (also stored as ``self.df``).

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If a required column is missing.
            Exception: Any CSV parsing error is logged and re-raised.
        """
        try:
            logger.info(f"Loading data from {file_path}...")

            # Validate file existence
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Data file not found at {file_path}")

            # Load with proper NA handling and dtype specification
            dtype_mapping = {
                'ordinance_id': str,
                'short_text': str,
                'full_text': str,
                'category': str,
                'fines': str,
                'date_enacted': str,
                'status': str,
                'links': str
            }

            try:
                df = pd.read_csv(
                    file_path,
                    na_values=["nan", "NaN", "NULL", "None", "MISSING", "TOO LONG", ""],
                    dtype=dtype_mapping
                )
            except Exception as e:
                logger.error(f"Failed to read CSV: {str(e)}")
                raise

            # Validate required columns
            required_columns = ['ordinance_id', 'short_text']
            missing_cols = [col for col in required_columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns: {missing_cols}")

            # Text cleaning with error handling
            text_cols = ["short_text", "full_text", "category", "fines","status","links"]
            for col in text_cols:
                if col not in df.columns:
                    continue
                try:
                    # After fillna("") no value is NA, so every cell can be
                    # whitespace-normalised unconditionally.
                    df[col] = df[col].fillna("").apply(
                        lambda x: " ".join(str(x).split())
                    )
                except Exception as e:
                    logger.warning(f"Error cleaning column {col}: {str(e)}")
                    # Fill missing values BEFORE casting: astype(str) first
                    # would turn NaN into the literal string "nan".
                    df[col] = df[col].fillna("").astype(str)

            # Handle missing categories
            if "category" in df.columns:
                df["category"] = df["category"].replace("", "Unknown").fillna("Unknown")
            # Handle missing status
            if "status" in df.columns:
                df["status"] = df["status"].replace("", "Status not specified").fillna("Status not specified")

            # Standardize dates with error handling
            if "date_enacted" in df.columns:
                try:
                    df["date_enacted"] = pd.to_datetime(
                        df["date_enacted"],
                        errors="coerce",
                        format='mixed'
                    ).dt.strftime('%B %d, %Y')
                    # Unparseable dates were coerced to NaT and come out as NaN.
                    df["date_enacted"] = df["date_enacted"].fillna("Date not available")
                except Exception as e:
                    logger.warning(f"Error parsing dates: {str(e)}")
                    df["date_enacted"] = "Date not available"

            # Remove very short or empty texts
            if "short_text" in df.columns:
                df = df[df["short_text"].str.len() > 30].copy()

            self.df = df.reset_index(drop=True)
            # Skip missing IDs so the string "nan" is never advertised as a
            # valid ordinance ID.
            self.available_ordinance_ids = set(
                self.df["ordinance_id"].dropna().astype(str).unique()
            )
            logger.info(f"Successfully loaded {len(self.df)} ordinances")
            return self.df

        except Exception as e:
            logger.error(f"Critical error in load_data: {str(e)}")
            raise

    def initialize_models(self):
        """Load the tokenizer, bi-encoder, cross-encoder and BM25 index.

        Each component has its own fallback: the tokenizer and bi-encoder try
        a legal-BERT checkpoint first and then a generic one; the
        cross-encoder and BM25 index are set to ``None`` when they cannot be
        built. Any error escaping those fallbacks is logged and re-raised.
        """
        try:
            logger.info("Initializing models...")

            # --- Tokenizer: legal-BERT first, plain BERT second, else None ---
            try:
                self.tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
                logger.info("Tokenizer initialized successfully")
            except Exception as primary_err:
                logger.warning(f"Failed to load legal-bert tokenizer, falling back to default: {str(primary_err)}")
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
                    logger.info("Fallback tokenizer initialized successfully")
                except Exception as fallback_err:
                    logger.error(f"Failed to initialize any tokenizer: {str(fallback_err)}")
                    self.tokenizer = None

            # --- Bi-encoder used for semantic search ---
            try:
                self.model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
            except Exception as encoder_err:
                logger.warning(f"Failed to load legal-bert, falling back to all-MiniLM: {str(encoder_err)}")
                self.model = SentenceTransformer('all-MiniLM-L6-v2')

            # --- Cross-encoder used for re-ranking (optional) ---
            try:
                self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
            except Exception as reranker_err:
                logger.warning(f"Failed to load cross-encoder: {str(reranker_err)}")
                self.cross_encoder = None

            # --- BM25 keyword index over the tokenized short_text column ---
            try:
                has_rows = self.df is not None and len(self.df) > 0
                if has_rows and self.tokenizer is not None:
                    corpus_tokens = []
                    for document in self.df["short_text"]:
                        corpus_tokens.append(self.tokenizer.tokenize(str(document)))
                    self.bm25 = BM25Okapi(corpus_tokens)
                    logger.info("BM25 initialized successfully")
                else:
                    logger.warning("No data or tokenizer available for BM25 initialization")
                    self.bm25 = None
            except Exception as bm25_err:
                logger.error(f"Failed to initialize BM25: {str(bm25_err)}")
                self.bm25 = None

            logger.info("Models initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing models: {str(e)}")
            raise

    def generate_embeddings(self):
        """Encode every ordinance and build the ID -> row-position lookup.

        The text encoded for each row is ``short_text`` followed by
        ``" [SEP] "`` and the category when a ``category`` column exists
        (that column is optional in ``load_data``); otherwise ``short_text``
        alone is used. Results are stored in ``self.embeddings`` and
        ``self.id_to_idx``.

        Raises:
            ValueError: If the model is not initialised or no data is loaded.
            Exception: Encoding errors are logged and re-raised.
        """
        try:
            logger.info("Generating embeddings...")

            if self.model is None:
                raise ValueError("Model not initialized")

            # Guard against load_data never having been called (df is None)
            # as well as an empty frame.
            if self.df is None or len(self.df) == 0:
                raise ValueError("No data available for embedding generation")

            # Use both short_text and category for better embeddings
            texts = self.df["short_text"].astype(str)
            if "category" in self.df.columns:
                texts = texts + " [SEP] " + self.df["category"].astype(str)
            text_list = texts.tolist()

            # Process in chunks to handle memory constraints
            chunk_size = 100
            embeddings = []
            for start in range(0, len(text_list), chunk_size):
                chunk = text_list[start:start + chunk_size]
                try:
                    embeddings.append(self.model.encode(chunk, show_progress_bar=False))
                except Exception as e:
                    logger.warning(f"Error encoding chunk {start//chunk_size}: {str(e)}")
                    raise

            self.embeddings = np.concatenate(embeddings)
            self.id_to_idx = {
                str(ordinance_id): position
                for position, ordinance_id in enumerate(self.df["ordinance_id"])
            }

            logger.info(f"Generated embeddings for {len(self.embeddings)} ordinances")

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def save_best_checkpoint(self, model_path: str = './best_model'):
        """Persist the model, tokenizer, data and metadata to ``model_path``.

        Files written into the directory:
          * the sentence-transformer model (``self.model.save``),
          * the tokenizer's own files (``save_pretrained``),
          * ``tokenizer_info.json`` with the tokenizer name and init kwargs,
          * ``data.csv`` with the loaded ordinances (when data is loaded),
          * ``metadata.json`` with parameters and component information.

        Args:
            model_path: Target directory; created if it does not exist.

        Raises:
            ValueError: If no model has been initialised.
            Exception: Any I/O or serialisation error is logged and re-raised.
        """
        try:
            if self.model is None:
                raise ValueError("Model not initialized")

            os.makedirs(model_path, exist_ok=True)

            # Save core models
            self.model.save(model_path)

            # Save tokenizer info
            tokenizer_label = str(self.tokenizer)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(model_path)
                tokenizer_info = {
                    'name_or_path': self.tokenizer.name_or_path,
                    'special_tokens_map': self.tokenizer.special_tokens_map,
                    'init_kwargs': self.tokenizer.init_kwargs
                }
                # Written to its own file: save_pretrained owns
                # 'tokenizer_config.json', so writing our summary under that
                # name would either be overwritten or corrupt the tokenizer's
                # real config. default=str keeps the dump from failing on
                # values that are not JSON-serialisable.
                with open(os.path.join(model_path, 'tokenizer_info.json'), 'w') as f:
                    json.dump(tokenizer_info, f, indent=4, default=str)

            # Save data and metadata
            if self.df is not None:
                self.df.to_csv(os.path.join(model_path, 'data.csv'), index=False)

            tokenizer_type = 'legal-bert' if 'legal-bert' in tokenizer_label else 'bert-base'
            metadata = {
                'best_params': self.best_params,
                # getattr: the attribute only exists after training unless the
                # constructor initialised it.
                'best_checkpoint': getattr(self, 'best_checkpoint', None),
                'retrieval_config': {
                    'default_k': getattr(self, 'default_k', 1),
                    'score_threshold': getattr(self, 'score_threshold', 50.0),
                    'max_same_category': getattr(self, 'max_same_category', 2),
                    'tokenizer_type': tokenizer_type
                },
                'components': {
                    'model_type': str(type(self.model)),
                    'cross_encoder_type': str(type(self.cross_encoder)) if self.cross_encoder else None,
                    'bm25_initialized': self.bm25 is not None,
                    # Duplicated here because load_best_checkpoint reads the
                    # tokenizer type from the 'components' section.
                    'tokenizer_type': tokenizer_type
                },
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }

            with open(os.path.join(model_path, 'metadata.json'), 'w') as f:
                json.dump(metadata, f, indent=4)

        except Exception as e:
            logger.error(f"Error saving complete checkpoint: {str(e)}")
            raise

    def load_best_checkpoint(self, model_path: str = './best_model'):
        """Restore the system state from a checkpoint directory.

        Reads ``metadata.json``, loads the sentence-transformer model,
        rebuilds the tokenizer, reloads ``data.csv`` through ``load_data``,
        recreates the BM25 index, regenerates the embeddings and, when the
        metadata says a cross-encoder was in use, reloads it.

        Tokenizer resolution order: the name/kwargs recorded in
        ``tokenizer_info.json`` (or ``tokenizer_config.json``), then the
        tokenizer files stored in the checkpoint directory itself, then
        ``bert-base-uncased``.

        Args:
            model_path: Checkpoint directory.

        Raises:
            FileNotFoundError: If the directory or ``metadata.json`` is missing.
            Exception: Any other loading error is logged and re-raised.
        """
        try:
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Checkpoint directory not found at {model_path}")

            # Load metadata first
            metadata_path = os.path.join(model_path, 'metadata.json')
            if not os.path.exists(metadata_path):
                raise FileNotFoundError("Missing metadata file in checkpoint")

            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            self.best_params = metadata.get('best_params', self.best_params)
            self.best_checkpoint = metadata.get('best_checkpoint', None)

            # Load retrieval config
            retrieval_config = metadata.get('retrieval_config', {})
            self.default_k = retrieval_config.get('default_k', 1)
            self.score_threshold = retrieval_config.get('score_threshold', 50.0)
            self.max_same_category = retrieval_config.get('max_same_category', 2)

            # Verify component compatibility. The saver records the tokenizer
            # type under 'retrieval_config'; accept either section so the
            # warning only fires when the tokenizer really was not legal-bert.
            components = metadata.get('components', {})
            tokenizer_type = (
                components.get('tokenizer_type')
                or retrieval_config.get('tokenizer_type')
                or ''
            )
            if 'legal-bert' not in tokenizer_type:
                logger.warning("Original tokenizer was not legal-bert - performance may vary")

            # Load model
            self.model = SentenceTransformer(model_path)

            # Reinitialize tokenizer exactly as before
            tokenizer = None
            candidate_paths = [
                os.path.join(model_path, 'tokenizer_info.json'),
                os.path.join(model_path, 'tokenizer_config.json'),
            ]
            tokenizer_config_path = next(
                (path for path in candidate_paths if os.path.exists(path)), None
            )
            if tokenizer_config_path is not None:
                try:
                    with open(tokenizer_config_path) as f:
                        tokenizer_info = json.load(f)
                    tokenizer = AutoTokenizer.from_pretrained(
                        tokenizer_info['name_or_path'],
                        **tokenizer_info['init_kwargs']
                    )
                except Exception as e:
                    logger.warning(f"Failed to load original tokenizer: {str(e)}")
            else:
                logger.warning("No tokenizer config found, initializing default tokenizer")

            if tokenizer is None:
                # The checkpoint directory normally holds the tokenizer files
                # written at save time; prefer them over a generic default so
                # BM25 tokenization matches the saved system.
                try:
                    tokenizer = AutoTokenizer.from_pretrained(model_path)
                    logger.info("Loaded tokenizer from checkpoint directory")
                except Exception as e:
                    logger.warning(f"Failed to load original tokenizer, falling back to default")
                    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.tokenizer = tokenizer

            # Load data
            data_path = os.path.join(model_path, 'data.csv')
            if os.path.exists(data_path):
                self.load_data(data_path)
            else:
                logger.warning("No data file found in checkpoint")

            has_rows = self.df is not None and len(self.df) > 0

            # Recreate BM25 index
            if has_rows and self.tokenizer is not None:
                tokenized_corpus = [self.tokenizer.tokenize(str(text)) for text in self.df["short_text"]]
                self.bm25 = BM25Okapi(tokenized_corpus)
                logger.info("Recreated BM25 index from loaded data")

            # Regenerate embeddings
            if has_rows:
                self.generate_embeddings()
                logger.info("Regenerated embeddings from loaded data")

            # Verify cross-encoder
            if components.get('cross_encoder_type'):
                try:
                    self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
                except Exception as e:
                    # Best effort: keep whatever cross-encoder (or None) was
                    # already set on the instance.
                    logger.warning(f"Failed to reload cross-encoder: {str(e)}")

        except Exception as e:
            logger.error(f"Error loading complete checkpoint: {str(e)}")
            raise

    def verify_consistency(self):
        """Check that every component needed for retrieval is in place.

        Checks performed: the bi-encoder is a ``SentenceTransformer``; the
        tokenizer exposes ``tokenize``; embeddings exist and have one row per
        ordinance; BM25 is built; the cross-encoder is either absent or a
        ``CrossEncoder``; the weight parameters are present.

        Returns:
            ``True`` when all checks pass, ``False`` otherwise (the failing
            checks are logged).
        """
        # Missing data is reported as a failed check rather than crashing on
        # len(None).
        row_count = len(self.df) if self.df is not None else None

        checks = {
            'model': isinstance(self.model, SentenceTransformer),
            'tokenizer': hasattr(self.tokenizer, 'tokenize'),
            'embeddings': (
                self.embeddings is not None
                and row_count is not None
                and len(self.embeddings) == row_count
            ),
            'bm25': self.bm25 is not None,
            'cross_encoder': self.cross_encoder is None or isinstance(self.cross_encoder, CrossEncoder),
            'parameters': all(k in self.best_params for k in ['semantic_weight', 'ce_weight'])
        }

        if not all(checks.values()):
            logger.error(f"Consistency check failed: {checks}")
            return False

        logger.info("All components verified and consistent")
        return True


    def train_retrieval_model(self, epochs: int = None, batch_size: int = None,
                            learning_rate: float = None) -> None:
        """Fine-tune the bi-encoder with category-based pairs, then checkpoint.

        For every ordinance one positive pair (another row sampled from the
        same category, label 1.0) and one negative pair (a row sampled from a
        different category, label 0.0) are created and the model is trained
        with a cosine-similarity loss. After training, the retrieval defaults
        are set on this instance, consistency is verified and the checkpoint
        is saved via ``save_best_checkpoint``.

        Args:
            epochs: Training epochs; defaults to ``best_params['epochs']``.
            batch_size: Batch size (capped at 8); defaults to
                ``best_params['batch_size']``.
            learning_rate: Optimizer learning rate; defaults to
                ``best_params['learning_rate']``.

        Raises:
            ValueError: If no data is loaded or no training pair can be built.
            RuntimeError: If the system is inconsistent after training.
            Exception: Training errors are logged and re-raised.
        """
        try:
            logger.info("Training retrieval model...")

            if self.df is None or len(self.df) == 0:
                raise ValueError("No data available for training")

            # Use best params if none provided
            epochs = epochs or self.best_params['epochs']
            batch_size = batch_size or self.best_params['batch_size']
            learning_rate = learning_rate or self.best_params['learning_rate']

            # Reduce batch size to alleviate memory pressure
            batch_size = min(batch_size, 8)

            # Prepare training examples
            train_examples = []
            for _, row in self.df.iterrows():
                try:
                    # Positive example (same category)
                    same_cat = self.df[self.df["category"] == row["category"]]
                    # Check emptiness BEFORE sampling: sample(1) raises on an
                    # empty frame, which previously made the guard unreachable.
                    if not same_cat.empty:
                        partner = same_cat.sample(1).iloc[0]
                        train_examples.append(InputExample(
                            texts=[str(row["short_text"]), str(partner["short_text"])],
                            label=1.0))

                    # Negative example (different category)
                    diff_cat = self.df[self.df["category"] != row["category"]]
                    if not diff_cat.empty:
                        contrast = diff_cat.sample(1).iloc[0]
                        train_examples.append(InputExample(
                            texts=[str(row["short_text"]), str(contrast["short_text"])],
                            label=0.0))
                except Exception as e:
                    logger.warning(f"Error creating training example for row {row.name}: {str(e)}")
                    continue

            if not train_examples:
                raise ValueError("No valid training examples could be created")

            # Create dataloader
            try:
                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
                train_loss = losses.CosineSimilarityLoss(self.model)
            except Exception as e:
                logger.error(f"Error creating dataloader: {str(e)}")
                raise

            # Configure training
            warmup_steps = min(100, len(train_dataloader) * epochs // 10)

            # Train the model with progress tracking
            try:
                self.model.fit(
                    train_objectives=[(train_dataloader, train_loss)],
                    epochs=epochs,
                    warmup_steps=warmup_steps,
                    optimizer_params={'lr': learning_rate},
                    show_progress_bar=True
                )

                # Save the final checkpoint info
                self.best_checkpoint = {
                    'epochs': epochs,
                    'batch_size': batch_size,
                    'learning_rate': learning_rate,
                    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                }

                # After training completes, configure THIS instance. The
                # previous code used a global variable named
                # `retrieval_system`, which fails (or touches another object)
                # whenever the instance is bound to a different name.
                self.default_k = 1  # Set your desired default
                self.score_threshold = 70.0  # Higher threshold for single result
                self.max_same_category = 1  # Strict category limiting

                # NOTE(review): self.embeddings were computed with the
                # pre-training weights; the consistency check below only
                # compares lengths -- confirm the caller regenerates them.

                # Verify before saving
                if not self.verify_consistency():
                    raise RuntimeError("System not in consistent state for saving")

                self.save_best_checkpoint()

                logger.info("Model training completed successfully")
            except Exception as e:
                logger.error(f"Error during model training: {str(e)}")
                raise

        except Exception as e:
            logger.error(f"Error in train_retrieval_model: {str(e)}")
            raise

    def expand_query(self, query: str) -> str:
        """Enhanced query expansion that better integrates synonyms"""
        try:
            expanded = query.lower()
            tokens = set(re.findall(r"\w+", expanded))

            # Add synonyms for each token found
            for term, synonyms in self.query_expansion_terms.items():
                if term in tokens:
                    expanded += " " + " ".join(synonyms[:5])

            # Also check for multi-word terms
            for term, synonyms in self.query_expansion_terms.items():
                if ' ' in term and term in expanded:
                    expanded += " " + " ".join(synonyms[:3])

            return " ".join(list(set(expanded.split())))
        except Exception as e:
            logger.warning(f"Query expansion failed: {str(e)}")
            return query

    def _is_ordinance_id_query(self, query: str) -> bool:
        """Check if the query is specifically looking for an ordinance ID"""
        # Patterns like "Ordinance 1234" or just "1234"
        return bool(re.match(r'(ordinance\s*)?\d+', query.lower()))

    def _get_ordinance_by_id(self, ordinance_id: str) -> Optional[Dict]:
        """Retrieve a single ordinance by ID if it exists"""
        try:
            # Clean the ID string
            ordinance_id = re.sub(r'[^0-9]', '', ordinance_id)
            if not ordinance_id:
                return None

            # Check if ID exists
            if ordinance_id not in self.available_ordinance_ids:
                return {
                    "ordinance_id": ordinance_id,
                    "category": "N/A",
                    "short_text": f"Ordinance {ordinance_id} data entry is MISSING from our records.",
                    "fines": "N/A",
                    "date_enacted": "N/A",
                    "status": "N/A",
                    "links": "N/A",
                    "confidence": "Not Found",
                    "score": "0.0%"
                }

            # Get the ordinance data
            idx = self.id_to_idx.get(ordinance_id)
            if idx is None or idx >= len(self.df):
                return {
                    "ordinance_id": ordinance_id,
                    "category": "N/A",
                    "short_text": f"Ordinance {ordinance_id} data entry is MISSING from our records.",
                    "fines": "N/A",
                    "date_enacted": "N/A",
                    "status": "N/A",
                    "links": "N/A",
                    "confidence": "Not Found",
                    "score": "0.0%"
                }

            row = self.df.iloc[idx]
            return {
                "ordinance_id": row["ordinance_id"],
                "category": row.get("category", ""),
                "short_text": row["short_text"],
                "fines": row.get("fines", ""),
                "date_enacted": row.get("date_enacted", ""),
                "status": row.get("status", "Status not specified"),
                "links": row.get("links", "Link not available"),
                "confidence": "High",
                "score": "100.0%"
            }
        except Exception as e:
            logger.warning(f"Error retrieving ordinance by ID {ordinance_id}: {str(e)}")
            return {
                "ordinance_id": ordinance_id,
                "category": "N/A",
                "short_text": f"Ordinance {ordinance_id} data entry is MISSING from our records.",
                "fines": "N/A",
                "date_enacted": "N/A",
                "status": "N/A",
                "links": "N/A",
                "confidence": "Error",
                "score": "0.0%"
            }

    def robust_scale(self, arr: np.ndarray) -> np.ndarray:
        """More robust scaling that handles edge cases"""
        try:
            arr = np.array(arr)
            if np.all(arr == arr[0]):  # All values equal
                return np.ones_like(arr) * 0.5  # Return neutral score
            return (arr - np.min(arr)) / (np.ptp(arr) + 1e-6)  # Add small epsilon
        except Exception as e:
            logger.warning(f"Error in robust_scale: {str(e)}")
            return np.zeros_like(arr)

    def calculate_accuracy_score(self, query_emb: np.ndarray, doc_emb: np.ndarray,
                               semantic_score: float, keyword_score: float,
                               ce_score: Optional[float] = None) -> float:
        """Calculate standardized accuracy score for a match"""
        try:
            # Calculate cosine similarity between query and document embeddings
            similarity = cosine_similarity([query_emb], [doc_emb])[0][0]

            # Normalize scores to [0, 1] range
            semantic_norm = (semantic_score + 1) / 2  # Convert from [-1, 1] to [0, 1]
            keyword_norm = keyword_score / np.max(keyword_score) if np.max(keyword_score) > 0 else 0

            # Combine scores with weights
            if ce_score is not None:
                ce_norm = (ce_score + 1) / 2  # Convert from [-1, 1] to [0, 1]
                combined_score = (
                    0.4 * semantic_norm +  # Semantic similarity weight
                    0.3 * keyword_norm +   # Keyword match weight
                    0.3 * ce_norm          # Cross-encoder weight
                )
            else:
                combined_score = (
                    0.6 * semantic_norm +  # Higher weight for semantic when no CE
                    0.4 * keyword_norm     # Lower weight for keyword when no CE
                )

            # Convert to percentage (0-100%) and cap at 100%
            accuracy_score = min(combined_score * 100, 100.0)

            return accuracy_score

        except Exception as e:
            logger.warning(f"Error calculating accuracy score: {str(e)}")
            return 0.0

    def _get_confidence_level(self, score: float) -> str:
        """Convert numeric score to confidence level"""
        try:
            score = float(score)
            if score > 70:  # 70%
                return "High"
            elif score > 50:  # 50%
                return "Medium"
            elif score > 30:  # 30%
                return "Low"
            return "Very Low"
        except:
            return "Unknown"

    def retrieve_ordinances(self, query: str, k: int = 5,
                      semantic_weight: float = None,
                      ce_weight: float = None) -> Union[List[Dict], str]:
        try:
            # Check for ordinance ID query first
            if self._is_ordinance_id_query(query):
                ordinance_id = re.sub(r'[^0-9]', '', query)
                ordinance = self._get_ordinance_by_id(ordinance_id)
                if ordinance:
                    return [ordinance]  # Return as a list to maintain consistent return type
                else:
                    return [{
                        "ordinance_id": ordinance_id,
                        "category": "N/A",
                        "short_text": f"Ordinance {ordinance_id} data entry is MISSING from our records.",
                        "fines": "N/A",
                        "date_enacted": "N/A",
                        "status": "N/A",
                        "links": "N/A",
                        "confidence": "Not Found",
                        "score": "0.0%"
                    }]

            # Use best params if none provided
            semantic_weight = semantic_weight or self.best_params['semantic_weight']
            ce_weight = ce_weight or self.best_params['ce_weight']
            final_ce_weight = ce_weight
            final_semantic_weight = 1 - ce_weight

            # Enhanced query expansion with better integration
            try:
                expanded_query = self.expand_query(query)
                query_embs = self.model.encode([query, expanded_query], convert_to_tensor=True)
                query_emb = torch.mean(query_embs, dim=0).cpu().numpy()

                # Enhanced tokenization for BM25 that includes expanded terms
                tokenized_query = self.tokenizer.tokenize(query.lower())
                expanded_tokens = self.tokenizer.tokenize(expanded_query.lower())
                all_tokens = tokenized_query * 2 + expanded_tokens  # Give original terms 2x weight

                logger.info(f"Original query: {query}")
                logger.info(f"Expanded query: {expanded_query}")
                logger.info(f"All search tokens: {all_tokens}")
            except Exception as e:
                logger.warning(f"Query expansion failed, using original query: {str(e)}")
                query_emb = self.model.encode(query, convert_to_tensor=True).cpu().numpy()
                all_tokens = self.tokenizer.tokenize(query.lower())
            # Step 2: First-stage retrieval with enhanced scoring
            # Semantic similarity with chunk processing for large datasets
            semantic_scores = []
            chunk_size = 500  # Process embeddings in chunks to avoid memory issues
            for i in range(0, len(self.embeddings), chunk_size):
                try:
                    chunk = self.embeddings[i:i + chunk_size]
                    scores = cosine_similarity([query_emb], chunk)[0]
                    semantic_scores.extend(scores)
                except Exception as e:
                    logger.warning(f"Error processing chunk {i//chunk_size}: {str(e)}")
                    semantic_scores.extend([0] * len(chunk))

            semantic_scores = np.array(semantic_scores)
            semantic_scores_norm = self.robust_scale(semantic_scores)

            # Enhanced keyword search with multiple fields
            keyword_scores = np.zeros(len(semantic_scores_norm))
            if self.tokenizer is not None and self.bm25 is not None:
                try:
                    # Use the combined tokens (original + expanded) for BM25
                    text_scores = self.bm25.get_scores(all_tokens)

                    # If available, also search in category field
                    if 'category' in self.df.columns:
                        category_bm25 = BM25Okapi([self.tokenizer.tokenize(str(cat).lower())
                                                for cat in self.df['category']])
                        category_scores = category_bm25.get_scores(all_tokens)
                        keyword_scores = 0.7 * text_scores + 0.3 * category_scores
                    else:
                        keyword_scores = text_scores

                    keyword_scores_norm = self.robust_scale(keyword_scores)
                except Exception as e:
                    logger.warning(f"Keyword search failed: {str(e)}")
                    keyword_scores_norm = np.zeros(len(semantic_scores_norm))
            else:
                logger.warning("Keyword search not available (tokenizer or BM25 not initialized)")
                keyword_scores_norm = np.zeros(len(semantic_scores_norm))

            # Get top candidates with diversity
            try:
                top_indices = np.argsort(semantic_scores)[::-1][:200]  # Wider initial pool
            except Exception as e:
                logger.error(f"Error sorting scores: {str(e)}")
                top_indices = np.arange(len(semantic_scores))[:200]

            # Step 3: Enhanced re-ranking with cross-encoder if available
            ce_scores = None
            if self.cross_encoder:
                try:
                    # Prepare richer context for cross-encoder using expanded query
                    pairs = []
                    for idx in top_indices:
                        row = self.df.iloc[idx]
                        context = f"Category: {row.get('category', '')}. "
                        context += f"Text: {row['short_text']}. "
                        if pd.notna(row.get('fines', None)):
                            context += f"Fines: {row['fines']}. "
                        if pd.notna(row.get('date_enacted', None)):
                            context += f"Enacted: {row['date_enacted']}."
                        pairs.append((expanded_query, context))  # Use expanded query here

                    ce_scores = self.cross_encoder.predict(pairs, show_progress_bar=False)
                except Exception as e:
                    logger.warning(f"Cross-encoder failed: {str(e)}")

            # Calculate accuracy scores for all candidates
            accuracy_scores = []
            for i, idx in enumerate(top_indices):
                try:
                    doc_emb = self.embeddings[idx]
                    ce_score = ce_scores[i] if ce_scores is not None else None
                    accuracy = self.calculate_accuracy_score(
                        query_emb, doc_emb,
                        semantic_scores[idx],
                        keyword_scores[idx] if 'keyword_scores' in locals() else 0,
                        ce_score
                    )
                    accuracy_scores.append((idx, accuracy))
                except Exception as e:
                    logger.warning(f"Error calculating accuracy for index {idx}: {str(e)}")
                    accuracy_scores.append((idx, 0.0))

            # Sort by accuracy score
            accuracy_scores.sort(key=lambda x: x[1], reverse=True)

            # Final ranking with diversity promotion and score threshold
            final_indices = []
            seen_categories = set()
            score_threshold = 30.0  # Minimum accuracy score threshold (50%), now 30 for less lenient

            for idx, score in accuracy_scores:
                if len(final_indices) >= k:
                    break

                # Skip results below threshold
                if score < score_threshold:
                    continue

                row = self.df.iloc[idx]
                category = row.get('category', '')

                # Promote diversity by limiting same-category results
                if category not in seen_categories or len(seen_categories) >= 5:
                    final_indices.append((idx, score))
                    seen_categories.add(category)

            # If no results meet the threshold, return a friendly message
            if not final_indices:
                return [{
                    "ordinance_id": "N/A",
                    "category": "N/A",
                    "short_text": "Sorry, I couldn't find any ordinance related to your question. Please try re-phrasing or adding more keywords.",
                    "fines": "N/A",
                    "date_enacted": "N/A",
                    "status": "N/A",
                    "links": "N/A",
                    "confidence": "No Match",
                    "score": "0.0%"
                }]

            # Prepare detailed results
            results = []
            for idx, accuracy_score in final_indices:
                try:
                    row = self.df.iloc[idx]
                    result = {
                        "ordinance_id": row["ordinance_id"],
                        "category": row.get("category", ""),
                        "short_text": row["short_text"],
                        "fines": row.get("fines", ""),
                        "date_enacted": row.get("date_enacted", ""),
                        "status": row.get("status", "Status not specified"),
                        "links": row.get("links", "Links not specified"),
                        "confidence": self._get_confidence_level(accuracy_score / 100),  # Convert back to [0,1] range
                        "score": f"{accuracy_score:.1f}%",
                        "details": {
                            "accuracy_score": f"{accuracy_score:.1f}%",
                            "semantic_score": f"{semantic_scores[idx]:.3f}",
                            "keyword_score": f"{keyword_scores[idx]:.1f}" if 'keyword_scores' in locals() else "N/A",
                            "cross_encoder_score": f"{ce_scores[i]:.1f}" if ce_scores is not None else "N/A"
                        }
                    }
                    results.append(result)
                except Exception as e:
                    logger.warning(f"Error formatting result {idx}: {str(e)}")
                    continue

            return results if results else [{
                "ordinance_id": "N/A",
                "category": "N/A",
                "short_text": "Sorry, I couldn't find any ordinance related to your question. Please try re-phrasing or adding more keywords",
                "fines": "N/A",
                "date_enacted": "N/A",
                "status": "N/A",
                "confidence": "No Match",
                "score": "0.0%"
            }]

        except Exception as e:
            logger.error(f"Error in retrieve_ordinances: {str(e)}")
            return [{
                "ordinance_id": "N/A",
                "category": "N/A",
                "short_text": "Sorry, I couldn't find any ordinance related to your question. Please try re-phrasing or adding more keywords",
                "fines": "N/A",
                "date_enacted": "N/A",
                "status": "N/A",
                "links": "N/A",
                "confidence": "Error",
                "score": "0.0%"
            }]

    def evaluate(self, test_queries: Dict[str, List[str]],
                semantic_weight: float, ce_weight: float) -> float:
        """Evaluate retrieval performance using nDCG with error handling"""
        try:
            all_ndcg = []
            for query, relevant_ids in test_queries.items():
                try:
                    results = self.retrieve_ordinances(
                        query,
                        k=1,
                        semantic_weight=semantic_weight,
                        ce_weight=ce_weight
                    )

                    # Handle case where retrieve_ordinances returns an error message
                    if isinstance(results, str):
                        logger.warning(f"Evaluation failed for query '{query}': {results}")
                        continue

                    retrieved_ids = [res["ordinance_id"] for res in results]

                    # Create relevance scores (1 for relevant, 0 otherwise)
                    true_relevance = [1 if id in relevant_ids else 0 for id in retrieved_ids]

                    # Skip if no relevant documents found
                    if sum(true_relevance) == 0:
                        continue

                    ideal_relevance = sorted(true_relevance, reverse=True)

                    # Calculate NDCG only if we have more than one document
                    if len(true_relevance) > 1:
                        ndcg = ndcg_score([true_relevance], [ideal_relevance])
                        all_ndcg.append(ndcg)
                except Exception as e:
                    logger.warning(f"Error evaluating query '{query}': {str(e)}")
                    continue

            return np.mean(all_ndcg) if all_ndcg else 0.0

        except Exception as e:
            logger.error(f"Error in evaluate: {str(e)}")
            return 0.0

    def optimize_hyperparameters(self, test_queries: Dict[str, List[str]],
                               n_trials: int = 20) -> None:
        """Search for the best retrieval/training settings with an Optuna study.

        Each trial samples a candidate configuration, calls
        ``train_retrieval_model`` and ``generate_embeddings`` with it, and is
        scored by ``evaluate`` (the study maximises that score). Once the study
        finishes, the winning values are merged into ``self.best_params`` and
        training plus embedding generation are run one more time with them.

        Args:
            test_queries: Labelled queries handed to ``evaluate`` in every trial.
            n_trials: Number of Optuna trials to run.

        Raises:
            Exception: Any error raised by the study itself or by the final
                retraining is logged and re-raised. An error inside a single
                trial is only logged, and that trial scores 0.0.
        """
        # NOTE(review): all trials train through this same instance; whether
        # train_retrieval_model restarts from fresh weights is not visible
        # here, so trials may build on each other's fine-tuning -- confirm.
        def fit_and_embed(settings):
            # Train with the given settings, then refresh the document embeddings
            self.train_retrieval_model(
                epochs=settings['epochs'],
                batch_size=settings['batch_size'],
                learning_rate=settings['learning_rate']
            )
            self.generate_embeddings()

        def objective(trial):
            try:
                # Sample one candidate configuration
                candidate = {
                    'semantic_weight': trial.suggest_float('semantic_weight', 0.4, 0.9),
                    'ce_weight': trial.suggest_float('ce_weight', 0.4, 0.9),
                    'batch_size': trial.suggest_categorical('batch_size', [8, 16, 32]),
                    'learning_rate': trial.suggest_float('learning_rate', 1e-6, 5e-5, log=True),
                    'epochs': trial.suggest_int('epochs', 1, 5)
                }

                fit_and_embed(candidate)

                # The trial's value is the retrieval score on the test queries
                return self.evaluate(
                    test_queries,
                    semantic_weight=candidate['semantic_weight'],
                    ce_weight=candidate['ce_weight']
                )
            except Exception as e:
                logger.error(f"Trial failed: {str(e)}")
                return 0.0  # A failed trial gets the minimum score

        try:
            study = optuna.create_study(direction='maximize')
            study.optimize(objective, n_trials=n_trials)

            # Keep the winning configuration
            self.best_params.update(study.best_params)
            logger.info(f"Best hyperparameters: {self.best_params}")
            logger.info(f"Best nDCG score: {study.best_value:.4f}")

            # Final training pass with the winning configuration
            fit_and_embed(self.best_params)
        except Exception as e:
            logger.error(f"Hyperparameter optimization failed: {str(e)}")
            raise

    def interactive_query_loop(self):
        """Run an interactive query loop with enhanced user experience"""
        print("\n=== Manila City Ordinance Retrieval System ===")

        # Print checkpoint information
        if hasattr(self, 'best_checkpoint'):
            print("Welcome!")

        else:
            print("\n⚠️ Using default model parameters (no optimized checkpoint found)")

        print("\nType 'exit' to quit the program\n")

        while True:
            try:
                query = input("\n🔎 Enter your query (or 'exit' to quit): ").strip()
                if query.lower() == 'exit':
                    break

                if not query:
                    print("Please enter a valid query")
                    continue

                start_time = time.time()
                results = self.retrieve_ordinances(query)
                elapsed = time.time() - start_time

                # Handle case where retrieve_ordinances returns an error message
                if isinstance(results, str):
                    print(f"\n{results}")
                    continue



                for i, res in enumerate(results, 1):
                    print(f"#{i} - ID: {res['ordinance_id']} | Category: {res['category']} | Score: {res['score']}")
                    print(f"📝 Summary: {res['short_text']}")
                    print(f"💰 Fines: {res['fines']}")
                    print(f"🔄 Status: {res['status']}\n")
                    print(f"🔗 Full Text/PDF: {res['links']}")
                    print(f"📅 Date Enacted: {res['date_enacted']}\n")

                if results:  # Only show if there were any results
                  print("\nℹ️ Jayoma Bot is an ordinance retrieval system designed to help users provide clarity to local laws in layman terms")
                  print("and should not be considered a substitute for professional legal advice or official legal interpretation.\n")

            except KeyboardInterrupt:
                print("\nOperation cancelled by user")
                break
            except Exception as e:
                print(f"\nAn error occurred: {str(e)}")
                continue

    def upload_dataset_colab(self) -> str:
        """Prompt for a file upload through ``google.colab.files`` and return its name.

        Returns:
            The first key of the mapping returned by ``files.upload()``, which
            is logged as the uploaded file's name.

        Raises:
            ImportError: If ``google.colab`` cannot be imported (logged, then
                re-raised).
            Exception: Any other failure, including a ``ValueError`` when the
                upload returned nothing, is logged and re-raised.
        """
        try:
            # Imported here because the module only exists inside Colab
            from google.colab import files

            received = files.upload()
            if not received:
                raise ValueError("No file was uploaded")

            # Only the first uploaded file is used
            first_name = next(iter(received.keys()))
            logger.info(f"Successfully uploaded file: {first_name}")
            return first_name
        except ImportError:
            logger.error("Google Colab module not found. This function only works in Google Colab environment.")
            raise
        except Exception as e:
            logger.error(f"Error uploading file: {str(e)}")
            raise

# Example usage with comprehensive error handling
if __name__ == "__main__":
    # Every early abort below goes through sys.exit, which always raises
    # SystemExit. The bare `exit` helper is injected by `site` for interactive
    # sessions (and replaced by IPython in notebooks), so it is not guaranteed
    # to exist or to actually stop the code that follows it. SystemExit is not
    # an `Exception` subclass, so the `except Exception` handlers in this block
    # do not swallow it.
    import sys

    try:
        # Initialize system
        retrieval_system = OrdinanceRetrievalSystem()

        # ====== BEGIN UPDATED WORKFLOW ======
        # Prefer a previously saved model; otherwise train from scratch
        checkpoint_path = './best_model'
        checkpoint_valid = os.path.exists(checkpoint_path)

        if checkpoint_valid:
            try:
                print("⏳ Loading existing optimized model...")
                retrieval_system.load_best_checkpoint()
                print("✅ Successfully loaded trained model")
            except Exception as e:
                # A broken checkpoint falls through to the training workflow
                print(f"⚠️ Failed to load checkpoint (will train new model): {str(e)}")
                checkpoint_valid = False

        if not checkpoint_valid:
            print("\n🆕 Initializing new model training workflow")

            # Load data first
            data_file = None
            try:
                # Check if running in Google Colab
                try:
                    import google.colab
                    is_colab = True
                except ImportError:
                    is_colab = False

                # Load data based on environment
                if is_colab:
                    try:
                        data_file = retrieval_system.upload_dataset_colab()
                    except Exception as e:
                        print(f"Failed to upload file in Colab: {str(e)}")
                        sys.exit(1)
                else:
                    data_file = "ordinance_data.csv"  # Local file
                    if not os.path.exists(data_file):
                        print(f"Error: Data file '{data_file}' not found.")
                        print("Please ensure your CSV file is in the same directory as this script.")
                        sys.exit(1)

                # Load the data
                print("\nLoading data...")
                df = retrieval_system.load_data(data_file)
                if df is None or len(df) == 0:
                    raise ValueError("No data loaded or empty dataset")
                print(f"Successfully loaded {len(df)} ordinances")

            except Exception as e:
                print(f"Failed to load data: {str(e)}")
                sys.exit(1)

            # Initialize models after data is loaded
            try:
                print("\nInitializing models...")
                retrieval_system.initialize_models()
                print("Models initialized successfully")
            except Exception as e:
                print(f"Failed to initialize models: {str(e)}")
                sys.exit(1)

            # Generate embeddings
            try:
                print("\nGenerating embeddings...")
                retrieval_system.generate_embeddings()
                print("Embeddings generated successfully")
            except Exception as e:
                print(f"Failed to generate embeddings: {str(e)}")
                sys.exit(1)

            # Labelled queries for optimization: query text -> relevant ordinance IDs
            test_queries = {
                "smoking ban": ["8521", "8677", "8563", "8521"],
                "property tax": ["8516", "8503", "8467", "8461", "8454"],
                "business license": ["8814", "8760", "8740"],
                "public health": ["8800", "8797", "8781", "8779"],
                "construction permit": ["8767", "8753", "8738", "8727"]
            }

            # Run hyperparameter optimization; a failure here is not fatal and
            # the query loop still starts afterwards
            try:
                print("\nStarting hyperparameter optimization...")
                retrieval_system.optimize_hyperparameters(test_queries, n_trials=9)
                print("Hyperparameter optimization completed")
            except Exception as e:
                print(f"Hyperparameter optimization failed, using defaults: {str(e)}")

        # Run interactive query loop
        retrieval_system.interactive_query_loop()

    except Exception as e:
        print(f"Fatal error: {str(e)}")
        sys.exit(1)