In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Installing packages

In [None]:
!pip install tiktoken

In [None]:
print("=" * 80)
print("INSTALLING REQUIRED PACKAGES FOR KAGGLE")
print("=" * 80)

import subprocess
import sys

def install_package(package):
    """
    Install package using pip if not already installed.
    
    Handles version-specific packages and gracefully manages
    installation failures.
    """
    try:
        # Extract package name from version specifications
        if '>=' in package or '==' in package:
            package_name = package.split('>=')[0].split('==')[0]
            __import__(package_name)
            print("Upgrading {}...".format(package_name))
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "-q", package])
            print("OK: {} upgraded".format(package))
        else:
            __import__(package.split('[')[0])
            print("OK: {} already installed".format(package))
    except ImportError:
        print("Installing {}...".format(package))
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "-q", package])
            print("OK: {} installed".format(package))
        except subprocess.CalledProcessError:
            print("WARNING: Failed to install {}. Continuing...".format(package))
    except Exception as e:
        print("WARNING: Error with {}: {}. Continuing...".format(package, str(e)))

# Core packages required for model optimization
required_packages = [
    "transformers",          # Model loading (has BitsAndBytesConfig)
    "peft",                   # LoRA adapter support
    "accelerate",            # Distributed model loading
    "torch",                  # PyTorch (usually pre-installed)
    "faiss-cpu",                     # Vector search
    "sentence-transformers",         # Embedding models
    "psutil",                        # Memory monitoring
    "safetensors>=0.4.0",           # Safe model serialization
    "bitsandbytes>=0.39.0",          # 8-bit quantization (more stable)
    "llama-cpp-python"
]

print("\nInstalling packages:\n")

for package in required_packages:
    install_package(package)

# ============================================================================
# VERIFICATION OF CRITICAL IMPORTS
# ============================================================================

print("\n" + "=" * 80)
print("VERIFYING CRITICAL IMPORTS")
print("=" * 80 + "\n")

critical_imports = {
    "transformers": "BitsAndBytesConfig",
    "torch": "torch version",
    "peft": "PeftModel",
    "bitsandbytes": "bitsandbytes version",
    "psutil": "memory monitoring"
}

all_imports_successful = True

try:
    import transformers
    print("OK: transformers {}".format(transformers.__version__))
    try:
        from transformers import BitsAndBytesConfig
        print("OK: BitsAndBytesConfig available")
    except ImportError:
        print("WARNING: BitsAndBytesConfig not available")
        all_imports_successful = False
except ImportError as e:
    print("ERROR: Could not import transformers: {}".format(str(e)))
    all_imports_successful = False

try:
    import torch
    print("OK: torch {}".format(torch.__version__))
except ImportError as e:
    print("ERROR: Could not import torch: {}".format(str(e)))
    all_imports_successful = False

try:
    import peft
    print("OK: peft {}".format(peft.__version__))
except ImportError as e:
    print("ERROR: Could not import peft: {}".format(str(e)))
    all_imports_successful = False

try:
    import bitsandbytes
    print("OK: bitsandbytes {}".format(bitsandbytes.__version__))
    from packaging import version
    if version.parse(bitsandbytes.__version__) >= version.parse("0.39.0"):
        print("    Version OK for 8-bit quantization")
    else:
        print("    WARNING: Version {} may have compatibility issues".format(
            bitsandbytes.__version__))
except ImportError as e:
    print("WARNING: Could not import bitsandbytes: {}".format(str(e)))
    print("    8-bit quantization may not work, but other strategies will")

try:
    import psutil
    print("OK: psutil installed")
except ImportError as e:
    print("WARNING: Could not import psutil: {}".format(str(e)))

try:
    import safetensors
    print("OK: safetensors available")
except ImportError as e:
    print("WARNING: Could not import safetensors: {}".format(str(e)))

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "=" * 80)
if all_imports_successful:
    print("ALL PACKAGES INSTALLED AND VERIFIED")
    print("Ready to proceed with model optimization")
else:
    print("PACKAGES INSTALLED WITH WARNINGS")
    print("Core functionality available. Some features may have limitations.")
print("=" * 80 + "\n")

# IMPORTS

In [2]:
# ============================================================================
# IMPORTS
# ============================================================================

import json
import os
import gc
from pathlib import Path

import numpy as np
import torch
import faiss
from tqdm import tqdm

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken
import json
import pickle
import os
import numpy as np
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, hf_hub_download
import torch
import json
import pickle
import numpy as np
import faiss
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download, hf_hub_download
from sentence_transformers import SentenceTransformer


# CONFIGURATION

In [3]:
# =============================================================================
# LOAD FROM HUGGING FACE - ALL NECESSARY VARIABLES & COMPONENTS
# =============================================================================

print("\n" + "=" * 80)
print("LOADING RAG SYSTEM FROM HUGGING FACE")
print("=" * 80)

# =============================================================================
# STEP 1: CONFIGURATION
# =============================================================================
print("\n[1/5] Configuration Setup")
print("-" * 80)

# Context & Token Configuration
CONTEXT_WINDOW = 1024           # Atlas model context window
PROMPT_TOKENS = 186             # System prompt size
USER_INPUT_TOKENS = 15          # Average query tokens
AVAILABLE_FOR_DOCS = CONTEXT_WINDOW - PROMPT_TOKENS - USER_INPUT_TOKENS

# Chunking Configuration (for reference, chunks already created)
CHUNK_SIZE_TOKENS = 680
CHUNK_OVERLAP_TOKENS = 50
MAX_CHUNKS_PER_DOC = 200
MAX_CHUNKS_FOR_CONTEXT = 3
MAX_TOKENS_PER_RETRIEVAL = AVAILABLE_FOR_DOCS
BATCH_SIZE = 32
MAX_CONTEXT_LENGTH = 1024
# Device Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✓ Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

print(f"✓ Context Window: {CONTEXT_WINDOW} tokens")
print(f"✓ Available for Documents: {AVAILABLE_FOR_DOCS} tokens")

# Data Path
DATA_PATH_NEW = "/kaggle/input/rag-data3/Rag_daridja_data_merged_cleaned_algerian.json"

# =============================================================================
# STEP 2: LOAD ATLAS MODEL & TOKENIZER
# =============================================================================
print("\n[2/5] Loading Atlas Model & Tokenizer")
print("-" * 80)

# Download model
ATLAS_MERGED_PATH = snapshot_download(
    repo_id="Sally004/Atlas2B_AlgerianDialect_SmokingData"
)
print(f"✓ Downloaded model to: {ATLAS_MERGED_PATH}")

# Fix tokenizer_config.json
config_path = Path(ATLAS_MERGED_PATH) / "tokenizer_config.json"
if config_path.exists():
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    # Remove problematic tokenizer_class field
    if 'tokenizer_class' in config:
        print(f"  Fixing tokenizer_class: {config['tokenizer_class']}")
        del config['tokenizer_class']
    
    # Fix extra_special_tokens if it's a list
    if 'extra_special_tokens' in config and isinstance(config['extra_special_tokens'], list):
        config['extra_special_tokens'] = {}
    
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print("  ✓ Fixed tokenizer config")

# Load tokenizer
atlas_tokenizer = AutoTokenizer.from_pretrained(
    str(ATLAS_MERGED_PATH),
    trust_remote_code=True
)
print("✓ Atlas tokenizer loaded")

# Load model
atlas_model = AutoModelForCausalLM.from_pretrained(
    str(ATLAS_MERGED_PATH),
    device_map="auto",
    trust_remote_code=True
)
param_count = sum(p.numel() for p in atlas_model.parameters())
print(f"✓ Atlas model loaded")
print(f"  Parameters: {param_count / 1e9:.2f}B")
print(f"  Memory: ~{param_count * 2 / 1e9:.2f} GB (FP16)")

# =============================================================================
# STEP 3: LOAD EMBEDDING MODEL
# =============================================================================
print("\n[3/5] Loading Embedding Model")
print("-" * 80)

EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embedding_model = SentenceTransformer(EMBED_MODEL_NAME)
embedding_dim = embedding_model.get_sentence_embedding_dimension()
print(f"✓ Embedding model loaded: {EMBED_MODEL_NAME}")
print(f"  Embedding dimension: {embedding_dim}")

# =============================================================================
# STEP 4: DOWNLOAD RETRIEVAL FILES FROM HUGGING FACE
# =============================================================================
print("\n[4/5] Downloading Retrieval Files from HF")
print("-" * 80)

REPO_ID = "Sally004/Sai_Dataset_NLP"

INDEX_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="retrieval/index/vector_index_token_based.faiss",
    repo_type="dataset"
)
print(f"✓ Downloaded FAISS index")

CHUNKS_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="retrieval/chunks/chunks_token_based.pkl",
    repo_type="dataset"
)
print(f"✓ Downloaded chunks")

VECTORS_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="retrieval/vectors/embeddings_token_based.npy",
    repo_type="dataset"
)
print(f"✓ Downloaded vectors")

EMBEDDING_INFO_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="metadata/embedding_info_token.json",
    repo_type="dataset"
)

CHUNK_STATS_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="metadata/chunk_stats_token.json",
    repo_type="dataset"
)

# =============================================================================
# STEP 5: LOAD RETRIEVAL COMPONENTS
# =============================================================================
print("\n[5/5] Loading Retrieval Components")
print("-" * 80)

# Load FAISS index
index = faiss.read_index(INDEX_PATH)
print(f"✓ FAISS index loaded: {index.ntotal} vectors")

# Load chunks
with open(CHUNKS_PATH, "rb") as handle:
    chunked_documents = pickle.load(handle)
print(f"✓ Chunks loaded: {len(chunked_documents)} chunks")

# Load metadata
with open(EMBEDDING_INFO_PATH, "r", encoding="utf-8") as handle:
    emb_info = json.load(handle)

with open(CHUNK_STATS_PATH, "r", encoding="utf-8") as handle:
    chunk_stats = json.load(handle)

print(f"✓ Metadata loaded")
print(f"  Total chunks: {chunk_stats.get('total_chunks', len(chunked_documents))}")
print(f"  Average chunk tokens: {chunk_stats.get('avg_chunk_tokens', 'N/A')}")

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
print("\n" + "=" * 80)
print("DEFINING HELPER FUNCTIONS")
print("=" * 80)

def count_tokens(text):
    """Count tokens using Atlas tokenizer (fallback to heuristic)."""
    if not text:
        return 0
    try:
        return len(atlas_tokenizer.encode(text, add_special_tokens=False))
    except Exception:
        return max(1, int(len(text.split()) / 0.75))

def count_tokens_exact(text, model_type="atlas"):
    """
    Count tokens exactly using the model's tokenizer.
    Ensures consistency between chunking and generation.
    """
    if not text:
        return 0
    
    try:
        if atlas_tokenizer:
            return len(atlas_tokenizer.encode(text))
        
        # Fallback to tiktoken if available
        import tiktoken
        try:
            encoder = tiktoken.get_encoding("cl100k_base")
            return len(encoder.encode(text))
        except:
            pass
    except Exception:
        pass
    
    # Final fallback: estimate (words / 0.75 for Arabic)
    words = len(text.split())
    return int(words / 0.75)

def embed_query(text):
    """Embed query into normalized vector space."""
    try:
        vec = embedding_model.encode(text, convert_to_numpy=True)
        return (vec / np.linalg.norm(vec)).astype("float32")
    except Exception as exc:
        print(f"✗ Error embedding query: {exc}")
        return None

def retrieve_documents(question, top_k=5, max_tokens=None):
    """Retrieve up to top_k chunks while respecting a token budget."""
    if max_tokens is None:
        max_tokens = MAX_TOKENS_PER_RETRIEVAL
    top_k = min(top_k, MAX_CHUNKS_FOR_CONTEXT)
    query_vec = embed_query(question)
    if query_vec is None:
        return []
    try:
        scores, indices = index.search(np.array([query_vec]), top_k * 2)
    except Exception as exc:
        print(f"✗ FAISS search error: {exc}")
        return []
    results = []
    total_tokens = 0
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0 or idx >= len(chunked_documents):
            continue
        chunk = chunked_documents[idx]
        tokens = chunk["metadata"].get("token_count", count_tokens(chunk["text"]))
        if results and total_tokens + tokens > max_tokens:
            break
        results.append({
            "text": chunk["text"],
            "metadata": chunk["metadata"],
            "score": float(score),
            "tokens": tokens
        })
        total_tokens += tokens
        if len(results) >= top_k:
            break
    return sorted(results, key=lambda item: item["score"], reverse=True)

print("✓ Helper functions defined:")
print("  - count_tokens()")
print("  - count_tokens_exact()")
print("  - embed_query()")
print("  - retrieve_documents()")

# Mobile configuration (not used in this setup)
USE_MOBILE_FOR_INFERENCE = False
ATLAS_MOBILE_PATH = Path("/kaggle/input/atlas-mobile/atlas-2B-merged-darija-Q4_K_M.gguf")

print("\n" + "=" * 80)
print("RAG SYSTEM LOADED SUCCESSFULLY FROM HUGGING FACE!")
print("=" * 80)
print("\nComponents Ready:")
print(f"  ✓ Atlas Model: {param_count / 1e9:.2f}B parameters")
print(f"  ✓ Atlas Tokenizer: Loaded")
print(f"  ✓ Embedding Model: {EMBED_MODEL_NAME}")
print(f"  ✓ FAISS Index: {index.ntotal} vectors")
print(f"  ✓ Document Chunks: {len(chunked_documents)}")
print(f"  ✓ Device: {DEVICE}")
print("\nYou can now run the interactive RAG cell!")
print("=" * 80 + "\n")


LOADING RAG SYSTEM FROM HUGGING FACE

[1/5] Configuration Setup
--------------------------------------------------------------------------------
✓ Device: cuda
  GPU: Tesla T4
  GPU Memory: 15.64 GB
✓ Context Window: 1024 tokens
✓ Available for Documents: 823 tokens

[2/5] Loading Atlas Model & Tokenizer
--------------------------------------------------------------------------------


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

✓ Downloaded model to: /root/.cache/huggingface/hub/models--Sally004--Atlas2B_AlgerianDialect_SmokingData/snapshots/f1d49f406b4d684b58910bacffcdfea632b50947
  ✓ Fixed tokenizer config
✓ Atlas tokenizer loaded
✓ Atlas model loaded
  Parameters: 2.61B
  Memory: ~5.23 GB (FP16)

[3/5] Loading Embedding Model
--------------------------------------------------------------------------------
✓ Embedding model loaded: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
  Embedding dimension: 384

[4/5] Downloading Retrieval Files from HF
--------------------------------------------------------------------------------
✓ Downloaded FAISS index
✓ Downloaded chunks
✓ Downloaded vectors

[5/5] Loading Retrieval Components
--------------------------------------------------------------------------------
✓ FAISS index loaded: 6142 vectors
✓ Chunks loaded: 6142 chunks
✓ Metadata loaded
  Total chunks: 6142
  Average chunk tokens: 586.3637251709541

DEFINING HELPER FUNCTIONS
✓ Helper functions d

# Interactive Rag

In [None]:
# =============================================================================
# STEP 4: Load fine-tuned Atlas model for inference (Support both models)
# =============================================================================

print("\n" + "=" * 80)
print("STEP 4/4: LOAD ATLAS MODEL FOR INFERENCE")
print("=" * 80 + "\n")

model_type="atlas"
print("✓ Embedding model loaded")

try:
        print("\nLoading model (this may take 1-2 minutes)...")
        param_count = sum(p.numel() for p in atlas_model.parameters())
        print(f"\n✓ Atlas model loaded successfully")
        print(f"  Parameters: {param_count / 1e9:.2f}B")
        print(f"  Memory: ~{param_count * 2 / 1e9:.2f} GB (FP16)")

except Exception as e:
    print(f"Error loading Atlas model: {str(e)}")
    raise


def count_tokens_exact(text, model_type="atlas"):
    """
    Count tokens exactly using the model's tokenizer or tiktoken.
    Ensures consistency between chunking and generation.
    """
    if not text:
        return 0
    
    try:
        if atlas_tokenizer:
            # For transformers models, use their tokenizer
            return len(atlas_tokenizer.encode(text))
        
        # Fallback to tiktoken if available
        import tiktoken
        try:
            encoder = tiktoken.get_encoding("cl100k_base")  # GPT-3.5/4 tokenizer
            return len(encoder.encode(text))
        except:
            pass
        
    except Exception:
        pass
    
    # Final fallback: estimate (words / 0.75 for Arabic)
    words = len(text.split())
    return int(words / 0.75)


def build_rag_prompt_with_token_budget(system_prompt, context, question, max_response_tokens):
    """
    Build RAG prompt while ensuring total tokens < CONTEXT_WINDOW.
    Returns prompt and actual tokens used.
    """
    # Base tokens for structure
    structure_tokens = count_tokens_exact("system: \nuser: السياق:\n\nالسؤال:\n\nassistant: ")
    
    # Calculate available tokens for content
    total_available = CONTEXT_WINDOW - max_response_tokens - structure_tokens
    
    # First, count system prompt tokens
    system_tokens = count_tokens_exact(system_prompt)
    
    # Then count question tokens
    question_tokens = count_tokens_exact(question)
    
    # Calculate available for context
    available_for_context = total_available - system_tokens - question_tokens
    
    if available_for_context <= 0:
        # Not enough tokens even without context
        available_for_context = 50  # Minimum context
    
    # Truncate context to fit
    context_tokens = count_tokens_exact(context)
    if context_tokens > available_for_context:
        # Need to truncate context
        # Estimate characters per token
        if context_tokens > 0:
            chars_per_token = len(context) / context_tokens
            max_chars = int(available_for_context * chars_per_token * 0.9)  # 90% safety
            # Try to cut at sentence boundary
            if max_chars < len(context):
                # Find last space or punctuation
                cut_point = max_chars
                for i in range(max_chars - 1, max(0, max_chars - 100), -1):
                    if context[i] in ['.', '؟', '!', '\n', ' ']:
                        cut_point = i + 1
                        break
                context = context[:cut_point] + "..."
                context_tokens = count_tokens_exact(context)
    
    # Build final prompt
    prompt = (
        f"system: {system_prompt}\n"
        f"user: السياق:\n{context}\n\n"
        f"السؤال:\n{question}\n"
        f"assistant: "
    )
    
    # Calculate total tokens
    total_tokens = (
        structure_tokens + 
        system_tokens + 
        context_tokens + 
        question_tokens
    )
    
    return prompt, total_tokens, context_tokens


def interpret_cosine_similarity(score):
    """
    Interpret cosine similarity score with semantic meaning.
    
    Cosine similarity ranges from 0 to 1:
    1.0 = identical meaning (same vector direction)
    0.0 = completely unrelated (orthogonal vectors)
    
    Returns:
        Interpretation dict with category, angle, and confidence level
    """
    # Calculate angle in degrees for better understanding
    angle = np.degrees(np.arccos(min(max(score, 0.0), 1.0)))
    
    if score >= 0.65:
        category = "VERY STRONG MATCH"
        confidence = "Very High"
        angle_desc = f"{angle:.1f}° (almost identical direction)"
        action = "RAG with high confidence"
    elif score >= 0.5:  # NEW THRESHOLD: 0.78
        category = "STRONG MATCH"
        confidence = "High"
        angle_desc = f"{angle:.1f}° (very close)"
        action = "RAG recommended"
    elif score >= 0.45:
        category = "GOOD MATCH"
        confidence = "Medium-High"
        angle_desc = f"{angle:.1f}° (good similarity)"
        action = "RAG suitable"
    elif score >= 0.35:
        category = "MODERATE MATCH"
        confidence = "Medium"
        angle_desc = f"{angle:.1f}° (somewhat related)"
        action = "Consider RAG"
    elif score >= 0.3:
        category = "WEAK MATCH"
        confidence = "Low"
        angle_desc = f"{angle:.1f}° (weak relation)"
        action = "Model knowledge preferred"
    elif score >= 0.25:
        category = "VERY WEAK MATCH"
        confidence = "Very Low"
        angle_desc = f"{angle:.1f}° (barely related)"
        action = "Model knowledge recommended"
    else:
        category = "POOR MATCH"
        confidence = "None"
        angle_desc = f"{angle:.1f}° (unrelated)"
        action = "Use model knowledge only"
    
    return {
        "category": category,
        "confidence_level": confidence,
        "angle_degrees": round(angle, 1),
        "angle_description": angle_desc,
        "recommended_action": action,
        "score": score
    }


def classify_intent_hybrid(question):
    """
    Hybrid semantic intent classifier: rule-based filter + LLM fallback.
    Detects greetings, insults, off-topic before RAG retrieval.
    
    Returns:
        tuple: (intent_label, canned_response or None)
    """
    question_lower = question.lower()
    
    # Fast rule-based filter for common greetings (90% coverage, 0ms latency)
    greeting_keywords = [
        "سلام", "السلام", "مرحبا", "صباح", "مساء", 
        "شكرا", "شكراً", "أهلا", "اهلا"
        "واش راك",      # How are you?
    "واش حالك",     # How are you?
    "واش كيفك",     # How are you?
    "كيف راك",      # How are you?
    "كيفاش راك",    # How are you?
    "واش الحالة",   # What's up?
    "كيف حالك",     # How are you?
    "كيف الحال",  
    ]
    
    words = question_lower.split()
    if len(words) <= 4:  # Short messages only
        if any(kw in question_lower for kw in greeting_keywords):
            # Verify it's not a smoking question with greeting words
                return "تحية", generate_special_response("تحية", question)
    
    # For everything else, use LLM few-shot classifier
    few_shot_examples = [
    # GREETINGS (Adding slang)
    ("واش الحالة", "تحية"),
    ("صحة خويا", "تحية"),
    ("واش يا البوت", "تحية"),
    ("واش راك داير فيها", "تحية"),
    
    # SMOKING (Adding frustration/emotions)
    ("كرهت حياتي ياخو", "سؤال تدخين"), # Emotional distress is usually about smoking here
    ("غلبتني السيجارة", "سؤال تدخين"),
    ("راني فشلان وتعبان", "سؤال تدخين"),
    ("حاب نتهنى من هاد السم", "سؤال تدخين"),
    
    # ACTUAL INSULTS (Be very specific)
    ("أنت حمار", "سب"),
    ("تفو عليك", "سب"),
    ("يا ولد الحرام", "سب"),
    
    # ACTUAL OFF-TOPIC
    ("كيفاش نطيب اللحم؟", "غير ذي صلة"),
    ("شكون ربح الماتش؟", "غير ذي صلة"),
    ("واش رايك في ميسي؟", "غير ذي صلة")
]
    
    examples_text = "\n".join([f"الرسالة: {q}\nالتصنيف: {c}" for q, c in few_shot_examples])
    
    classifier_prompt = f"""تصنف الرسالة في واحد من: سؤال تدخين / تحية / سب / غير ذي صلة

أمثلة:
{examples_text}

الرسالة: {question}
التصنيف:"""
    
    try:
            inputs = atlas_tokenizer(
                    classifier_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512
                ).to(DEVICE)
                
            with torch.no_grad():
                    outputs = atlas_model.generate(
                        **inputs,
                        max_new_tokens=15,
                        temperature=0.1,
                        top_p=0.5,
                        do_sample=True,
                        pad_token_id=atlas_tokenizer.eos_token_id
                    )
            
            intent_raw = atlas_tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "التصنيف:" in intent_raw:
                intent_raw = intent_raw.split("التصنيف:")[-1].strip()
        
            # Count label occurrences for robust classification
            scores = {
                "تحية": intent_raw.count("تحية"),
                "سب": intent_raw.count("سب"),
                "غير ذي صلة": intent_raw.count("غير ذي صلة"),
                "سؤال تدخين": intent_raw.count("سؤال تدخين")
            }
            
            top_intent = max(scores, key=scores.get)
            
            if top_intent == "تحية" or scores["تحية"] > 0:
                return "greeting", generate_special_response(top_intent, question)
            elif top_intent == "سب" or scores["سب"] > 0:
                return "insult", "معليش، نحترم الجميع هنا. عندك سؤال حول التدخين؟"
            elif top_intent == "غير ذي صلة" or scores["غير ذي صلة"] > 0:
                return "off_topic", "خاطيني، أنا نجاوب غير على أسئلة التدخين والإقلاع عنه."
            
            return "smoking", None
        
    except Exception as e:
            print(f"Intent classification error: {e}")
            return "smoking", None


def detect_query_type_by_keywords(question):
    """
    Detect query type using keyword matching for greetings only.
    We no longer use keyword detection for offensive language.
    """
    question_lower = question.lower()
    
    # Greeting patterns in Algerian Arabic/Darija only
    greeting_patterns = [
        "السلام", "سلام", "مرحبا", "أهلا", "صباح", "مساء", "مساء الخير", "صباح الخير",
        "اهلا", "مرحبا بيك", "مرحبا بك", "كيف حالك", "كيف الحال", "واش كي",
         "مْرحبا", "صباح النور", "مساء النور", "كي راك", "واش راك",
        "واش اخبارك", "واش الأوضاع", "واش أحوالك", "عليكم السلام", "وعليكم السلام",
        "يا مرحبا", "سلامو عليكوم"
    ]
    
    # Check for greeting keywords
    greeting_score = 0
    for pattern in greeting_patterns:
        if pattern in question_lower:
            greeting_score += 1
    
    # Determine query type based on scores
    if greeting_score > 0:
        return {
            "type": "greeting",
            "detection_method": "keyword_matching",
            "confidence": min(greeting_score / 3, 1.0),  # Scale to 0-1
            "reason": f"Detected {greeting_score} greeting keyword(s)"
        }
    else:
        return {
            "type": "normal",
            "detection_method": "none",
            "confidence": 0.0,
            "reason": "No special keywords detected"
        }


def detect_special_query_type_from_documents(retrieved_documents, question):
    """
    Detect special query types from retrieved documents AND keyword matching.
    This combines both approaches for better accuracy.
    """
    if not retrieved_documents:
        return {
            "type": "normal",
            "detection_method": "none",
            "confidence": 0.0,
            "reason": "No documents retrieved"
        }
    
    # First, check if any retrieved document has special titles
    special_titles = {
        "ترحيب عام واسئلة اجتماعية": "تحية",
        "كلام قبيح أو سب": "سب"
    }
    GREET_THRESHOLD = 0.85
    
    best_score = retrieved_documents[0]["score"] if retrieved_documents else 0
    best_title = retrieved_documents[0]["metadata"].get("title", "N/A") if retrieved_documents else ""
    
    # Check if the best document has a special title
    if best_title in special_titles:
        if best_title == "ترحيب عام واسئلة اجتماعية" and best_score < GREET_THRESHOLD:
            pass
        else:
            # For offensive language, only trust if score >= 0.7
            if special_titles[best_title] == "offensive" and best_score < 0.7:
                # Offensive document but low similarity, might be false positive
                return {
                    "type": "normal",
                    "detection_method": "document_title_but_low_score",
                    "confidence": best_score,
                    "reason": f"Offensive document found but low similarity ({best_score:.4f} < 0.7)"
                }
            
            return {
                "type": special_titles[best_title],
                "detection_method": "document_title",
                "confidence": best_score,
                "reason": f"Document title indicates {special_titles[best_title]} type"
            }
    
    # If not found in titles, check all retrieved documents for offensive
    offensive_detected = False
    offensive_score = 0
    for doc in retrieved_documents[:3]:  # Check top 3
        title = doc["metadata"].get("title", "")
        if title == "كلام قبيح أو سب":
            offensive_detected = True
            offensive_score = doc["score"]
            if doc["score"] >= 0.7:
                return {
                    "type": "offensive",
                    "detection_method": "document_title_in_top_3",
                    "confidence": doc["score"],
                    "reason": f"Offensive document found with high similarity ({doc['score']:.4f} ≥ 0.7)"
                }
    
    # If offensive detected but score too low, note it
    if offensive_detected:
        return {
            "type": "normal",
            "detection_method": "document_title_but_low_score",
            "confidence": offensive_score,
            "reason": f"Offensive document found but low similarity ({offensive_score:.4f} < 0.7)"
        }
    
    # If still not found, use keyword matching for greetings only
    keyword_result = detect_query_type_by_keywords(question)
    
    # Only trust keyword detection for greetings if confidence is high enough
    if keyword_result["type"] == "greeting" and keyword_result["confidence"] > 0.3:
        return keyword_result
    
    return {
        "type": "normal",
        "detection_method": "none",
        "confidence": 0.0,
        "reason": "Normal smoking-related query"
    }

    
def embed_query_simple(text):
    """Embed text using pre-loaded model."""
    if not text or embedding_model is None:
        return None
    
    try:
        embedding = embedding_model.encode(text, convert_to_numpy=True)
        embedding = embedding / np.linalg.norm(embedding)
        return embedding.astype("float32")
    except Exception as e:
        print(f"Embedding error: {e}")
        return None
        
def select_greeting_by_similarity(user_question):
    """Select most appropriate greeting based on semantic similarity using anchor phrases."""
    
    # 1. Map 'Anchors' (what user says) to 'Responses' (what bot says)
    # This ensures high similarity scores
    greeting_map = [
        {"anchors": ["السلام عليكم", "سلام"], "response": "وعليكم السلام ورحمة الله! مرحبا بيك، قولي واش هو سؤالك على التدخين؟"},
        {"anchors": ["أهلا", "مرحبا", "واش راك"], "response": "أهلا بيك! واش راك؟ كيفاش نقدر نعاونك اليوم في موضوع التدخين؟"},
        {"anchors": ["صباح الخير", "كي صبحت"], "response": "صباح النور والسرور! واش راك؟ كاش ما نقدر نعاونك في موضوع التدخين اليوم؟"},
        {"anchors": ["مساء الخير", "كي عشيت"], "response": "مساء الخير والأنوار! واش أحوالك؟ راني هنا إذا سحقيت كاش نصيحة على التدخين"}
    ]
    
    # Extract just the first anchor for embedding (or a representative phrase)
    anchor_phrases = [m["anchors"][0] for m in greeting_map]
    responses = [m["response"] for m in greeting_map]

    # 2. Embed user question
    user_embedding = embed_query_simple(user_question)
    if user_embedding is None:
        return responses[1], 0.0 # Default to general 'Ahlan'
    
    # 3. Embed the Anchors (Note: In production, do this ONCE outside the function to save speed)
    greeting_embeddings = []
    for anchor in anchor_phrases:
        emb = embed_query_simple(anchor)
        if emb is not None:
            greeting_embeddings.append(emb)
        else:
            # Fallback to zero vector if embedding fails
            greeting_embeddings.append(np.zeros_like(user_embedding))
    
    # 4. Calculate similarities
    greeting_embeddings = np.array(greeting_embeddings)
    
    # Ensure vectors are normalized for true cosine similarity
    # similarity = dot(A, B) / (norm(A) * norm(B))
    similarities = np.dot(greeting_embeddings, user_embedding)
    
    # 5. Find best match
    best_idx = np.argmax(similarities)
    best_similarity = float(similarities[best_idx])
    selected_greeting = responses[best_idx]
    
    return selected_greeting, best_similarity
    
def generate_special_response(query_type_info, question):
    """
    Generate response for special query types (greetings, offensive language).
    """
    if isinstance(query_type_info, str):
        q_type = query_type_info
        if query_type_info == "تحية":
            selected, similarity = select_greeting_by_similarity(question)
            return selected
        elif query_type_info == "سب":
            # Direct refusal for offensive language in Algerian Darija
            return "معليش، ما نقدرش نجاوب على هاد النوع من الكلام. تكلم باحترام ونقدر نعاونك."
    else:
        q_type = query_type_info.get("type")
        if query_type_info["type"] == "تحية":
            # Simple greeting response in Algerian Darija
            selected, similarity = select_greeting_by_similarity(question)
            return selected
        elif query_type_info["type"] == "سب" :
            # Direct refusal for offensive language in Algerian Darija
            return "معليش، ما نقدرش نجاوب على هاد النوع من الكلام. تكلم باحترام ونقدر نعاونك."
    return None


def generate_answer(question, context=None, max_tokens_for_response=120, use_rag=True, temperature=None):
    """
    Generate answer using fine-tuned Atlas model with anti-hallucination measures.
    
    Improvements:
    - Chain-of-verification prompt (hidden reasoning steps)
    - Lower temperature for RAG (0.2 vs 0.35)
    - Lower repetition penalty (1.2 vs 1.5)
    - Dynamic token allocation (80-150 based on confidence)
    
    Args:
        question: User's question (string)
        context: Retrieved context from RAG (string) or None if using model knowledge
        max_tokens_for_response: Max tokens to generate
        use_rag: Whether to use RAG context or model knowledge
        temperature: Specific temperature to use (if None, uses default based on use_rag)
    """
    
    # Ensure max_tokens_for_response is reasonable
    
    if use_rag and context:
        # RAG with anti-hallucination prompt
        system_prompt = (
        "أنت خبير جزائري مختص في التوعية ضد التدخين. "
        "تحدث بالدارجة الجزائرية البيضاء (فصيحة تقنياً). "
        "التزم بالحقائق العلمية فقط. ممنوع الدراما، ممنوع السياسة، وممنوع التحدث عن دول أخرى .جاوب مباشرة تقنياً. ابدأ الجواب بـ 'بناءً على المعلومات المتوفرة"
        "خاطب المستخدم بصيغة المذكر دائماً إلا إذا ذكر عكس ذلك."
    )
        # Anti-hallucination: Add hidden reasoning steps
        prompt = f"""system: {system_prompt}

{context if context else "ملاحظة: لا يوجد سياق خارجي، جاوب باختصار من القواعد العامة للإقلاع عن التدخين."}

التعليمات الإجبارية:
1. استخرج النصائح من السياق (Context) إذا كان متوفراً.
2. إذا كان السؤال بعيداً عن التدخين، قل: "أنا هنا للمساعدة في الإقلاع عن التدخين فقط".
3. ممنوع نهائياً ذكر أي جمل درامية غير واقعية.
4. الجواب يكون في شكل نقاط (Bullet points) ليكون واضحاً.
5. لا تزد عن 80 كلمة لتجنب انقطاع النص.

user: السؤال: {question}

assistant: """
        
        actual_prompt_tokens = count_tokens(prompt)
        max_safe = CONTEXT_WINDOW - actual_prompt_tokens - 10
        max_tokens_for_response = min(max_tokens_for_response, max_safe)
        # Lower temperature for factual responses
        if temperature is None:
            temperature = 0.2  # Was 0.35
        
        # Calculate tokens for budget tracking
        prompt_tokens = count_tokens_exact(prompt)
        
    else:
        # Model knowledge (no RAG)
        system_prompt = (
            "أنت مساعد جزائري مختص في التدخين والإقلاع عنه، تهدر بالدارجة الجزائرية. "
            "جاوب بإجابات قصيرة ومباشرة وعملية بلا خطبة. ممنوع تمد نصائح طبية من راسك. ممنوع الفلسفة."
        )
        
        prompt = f"""system: {system_prompt}

user: {question}

assistant: """
        
        if temperature is None:
            temperature = 0.4  # Was 0.7
        
        prompt_tokens = count_tokens_exact(prompt)
    
    try:
        inputs = atlas_tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_CONTEXT_LENGTH
            ).to(DEVICE)
            
        with torch.no_grad():
                outputs = atlas_model.generate(
                    **inputs,
                    max_new_tokens=max_tokens_for_response,
                    temperature=temperature,
                    top_p=0.4,              # Was 0.8
                    top_k=40,
                    do_sample=True,
                    repetition_penalty=1.2,  # Was 1.15
                    pad_token_id=atlas_tokenizer.eos_token_id,
                    eos_token_id=atlas_tokenizer.eos_token_id
                )
            
                answer = atlas_tokenizer.decode(outputs[0], skip_special_tokens=True)
                
                if "assistant:" in answer:
                    answer = answer.split("assistant:")[-1].strip()
        
        # Return answer and token usage info
        answer_tokens = count_tokens_exact(answer)
        return {
            "answer": answer,
            "total_prompt_tokens": prompt_tokens,
            "answer_tokens": answer_tokens,
            "context_tokens_used": count_tokens_exact(context) if context else 0,
            "total_tokens": prompt_tokens + answer_tokens
        }
    
    except Exception as e:
        print(f"Error generating answer: {str(e)}")
        return {
            "answer": "معليش، وقعت مشكلة في الجواب. حاول تاني بعد شوية.",
            "total_prompt_tokens": 0,
            "answer_tokens": 0,
            "context_tokens_used": 0,
            "total_tokens": 0
        }


def process_query(question):
    """
    Complete RAG pipeline with semantic intent classification and anti-hallucination.
    
    Pipeline:
    1. Pre-RAG intent classification (greetings/off-topic/insults)
    2. Document retrieval (only for smoking questions)
    3. Dynamic token allocation (80-150 based on confidence)
    4. Anti-hallucination generation
    
    Updated for token-based retrieval with budget enforcement.
    """
    
    # STEP 1: Pre-RAG Intent Classification
    intent, canned_response = classify_intent_hybrid(question)
    
    # If non-smoking intent detected, return canned response immediately
    if canned_response:
        return {
            "answer": canned_response,
            "confidence": 1.0,
            "selected_document": None,
            "all_documents": [],
            "context_tokens": 0,
            "available_for_response": 0,
            "rag_used": False,
            "query_type": intent,
            "reason": f"Intent classification: {intent} (handled pre-RAG)",
            "similarity_interpretation": None,
            "special_handling": True,
            "token_usage": {
                "prompt_tokens": count_tokens_exact(f"system: \nuser: {question}\nassistant: "),
                "answer_tokens": count_tokens_exact(canned_response),
                "total_tokens": count_tokens_exact(f"system: \nuser: {question}\nassistant: {canned_response}"),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    # STEP 2: Calculate max tokens for context retrieval
    system_prompt_tokens = count_tokens_exact(
        "أنت مساعد جزائري مختص في التدخين والإقلاع عنه، تهدر بالدارجة الجزائرية. "
        "جاوب بإجابات قصيرة ومباشرة وعملية بلا خطبة."
    )
    
    question_tokens = count_tokens_exact(question)
    structure_tokens = count_tokens_exact("system: \nuser: السياق:\n\nالسؤال:\n\nassistant: ")
    
    # Reserve tokens for response 
    response_tokens_reserved = 0
    
    # Calculate available for context
    max_context_tokens = CONTEXT_WINDOW - (
        system_prompt_tokens + 
        question_tokens + 
        structure_tokens + 
        response_tokens_reserved
    )
    
    # Ensure we have at least some context
    if max_context_tokens < 50:
        max_context_tokens = 50
    
    # STEP 3: Retrieve documents
    retrieved = retrieve_documents(
        question, 
        top_k=5,
        max_tokens=max_context_tokens
    )
    
    if not retrieved:
        # No documents found, use model knowledge
        result = generate_answer(question, context=None, max_tokens_for_response=80, use_rag=False)
        return {
            "answer": result["answer"],
            "confidence": 0.0,
            "selected_document": None,
            "all_documents": [],
            "context_tokens": 0,
            "available_for_response": AVAILABLE_FOR_DOCS,
            "rag_used": False,
            "query_type": "smoking",
            "reason": "No relevant documents found",
            "similarity_interpretation": interpret_cosine_similarity(0.0),
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    best_doc = retrieved[0]
    best_score = best_doc["score"]
    best_title = best_doc["metadata"].get("title", "N/A")
    context = best_doc["text"]
    
    # Check for special document-based detection (offensive language)
    query_type_info = detect_special_query_type_from_documents(retrieved, question)
    
    # If it's a special query type, generate appropriate response
    if query_type_info["type"] in ["تحية", "سب"]:
        answer = generate_special_response(query_type_info, question)
        
        return {
            "answer": answer,
            "confidence": round(best_score, 4),
            "selected_document": {
                "rank": 1,
                "title": best_title,
                "score": best_score,
                "tokens": best_doc.get("tokens", 0)
            },
            "all_documents": [
                {
                    "rank": i + 1,
                    "title": doc["metadata"].get("title", "N/A"),
                    "score": round(doc["score"], 4),
                    "tokens": doc.get("tokens", 0)
                }
                for i, doc in enumerate(retrieved)
            ],
            "context_tokens": 0,
            "available_for_response": AVAILABLE_FOR_DOCS,
            "rag_used": False,
            "query_type": query_type_info["type"],
            "detection_method": query_type_info["detection_method"],
            "reason": f"{query_type_info['reason']} (confidence: {query_type_info['confidence']:.2f})",
            "similarity_interpretation": interpret_cosine_similarity(best_score),
            "special_handling": True,
            "token_usage": {
                "prompt_tokens": count_tokens_exact(f"system: \nuser: {question}\nassistant: "),
                "answer_tokens": count_tokens_exact(answer),
                "total_tokens": count_tokens_exact(f"system: \nuser: {question}\nassistant: {answer}"),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    # Get interpretation of the similarity score
    similarity_interpretation = interpret_cosine_similarity(best_score)
    
    # STEP 4: Dynamic token allocation based on confidence
    CONFIDENCE_THRESHOLD = 0.65
    
    if best_score >= 0.77:
        max_tokens_for_response = 250
        temperature = 0.32
    elif best_score >= CONFIDENCE_THRESHOLD:
        max_tokens_for_response = 180
        temperature = 0.4
    elif best_score >= 0.45:
        max_tokens_for_response = 150
        temperature = 0.45
        # ← Move HERE: Always concatenate all docs first
        context_parts = []
        for doc in retrieved:
            context_parts.append(f"[{doc['metadata']['title']}]\n{doc['text']}")
        context = "\n---\n".join(context_parts)
        
        # Calculate context tokens and available for response
        context_tokens_used = sum(doc["tokens"] for doc in retrieved)
        available_for_response = CONTEXT_WINDOW - (
            system_prompt_tokens + 
            question_tokens + 
            structure_tokens + 
            context_tokens_used
        )
        
        if available_for_response < 100:
            available_for_response = 100
            max_context_tokens = CONTEXT_WINDOW - (
                system_prompt_tokens + 
                question_tokens + 
                structure_tokens + 
                available_for_response
            )
            retrieved = retrieve_documents(question, top_k=5, max_tokens=max_context_tokens)
            if retrieved:
                best_doc = retrieved[0]
                context_parts = []
                for doc in retrieved:
                    context_parts.append(f"[{doc['metadata']['title']}]\n{doc['text']}")
                context = "\n---\n".join(context_parts)
                context_tokens_used = sum(doc["tokens"] for doc in retrieved)
        
        # Use RAG with medium confidence
        result = generate_answer(question, context, max_tokens_for_response=max_tokens_for_response, 
                               use_rag=True, temperature=temperature)
        
        return {
            "answer": result["answer"],
            "confidence": round(best_score, 4),
            "selected_document": {
                "rank": 1,
                "title": best_title,
                "score": best_score,
                "tokens": best_doc.get("tokens", 0)
            },
            "all_documents": [
                {
                    "rank": i + 1,
                    "title": doc["metadata"].get("title", "N/A"),
                    "score": round(doc["score"], 4),
                    "tokens": doc.get("tokens", 0)
                }
                for i, doc in enumerate(retrieved)
            ],
            "context_tokens": context_tokens_used,
            "available_for_response": max_tokens_for_response,
            "rag_used": True,
            "query_type": "smoking",
            "reason": f"Good match RAG ({best_score:.4f} ≥ 0.70)",
            "similarity_interpretation": similarity_interpretation,
            "generation_temperature": temperature,
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "context_tokens": result.get("context_tokens_used", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    else:
        # Confidence too low, use model knowledge
        if best_score >= 0.3:
            temperature = 0.45
            max_tokens = 80
        else:
            temperature = 0.5
            max_tokens = 70
        
        result = generate_answer(question, context=None, max_tokens_for_response=max_tokens, 
                               use_rag=False, temperature=temperature)
        
        return {
            "answer": result["answer"],
            "confidence": round(best_score, 4),
            "selected_document": {
                "rank": 1,
                "title": best_title,
                "score": best_score,
                "tokens": best_doc.get("tokens", 0)
            },
            "all_documents": [
                {
                    "rank": i + 1,
                    "title": doc["metadata"].get("title", "N/A"),
                    "score": round(doc["score"], 4),
                    "tokens": doc.get("tokens", 0)
                }
                for i, doc in enumerate(retrieved)
            ],
            "context_tokens": 0,
            "available_for_response": AVAILABLE_FOR_DOCS,
            "rag_used": False,
            "query_type": "smoking",
            "reason": f"Cosine similarity {best_score:.4f} < {CONFIDENCE_THRESHOLD} (threshold)",
            "similarity_interpretation": similarity_interpretation,
            "generation_temperature": temperature,
            "model_knowledge_tokens": max_tokens,
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    # High confidence RAG path (>= 0.78)
    context_tokens_used = sum(doc["tokens"] for doc in retrieved)
    
    # Calculate available for response
    available_for_response = CONTEXT_WINDOW - (
        system_prompt_tokens + 
        question_tokens + 
        structure_tokens + 
        context_tokens_used
    )
    
    # Ensure we have at least minimum tokens for response
    if available_for_response < max_tokens_for_response:
        max_tokens_for_response = max(55, min(max_tokens_for_response, available_for_response))
        if available_for_response < 55:
            # Need to reduce context
            available_for_response = 55
            max_context_tokens = CONTEXT_WINDOW - (
                system_prompt_tokens + 
                question_tokens + 
                structure_tokens + 
                available_for_response
            )
            retrieved = retrieve_documents(question, top_k=5, max_tokens=max_context_tokens)
            if retrieved:
                best_doc = retrieved[0]
                context_parts = []
                for doc in retrieved:
                    context_parts.append(f"[{doc['metadata']['title']}]\n{doc['text']}")
                context = "\n---\n".join(context_parts)
                context_tokens_used = sum(doc["tokens"] for doc in retrieved)
    
    result = generate_answer(question, context, max_tokens_for_response=max_tokens_for_response, 
                           use_rag=True, temperature=temperature)
    
    return {
        "answer": result["answer"],
        "confidence": round(best_score, 4),
        "selected_document": {
            "rank": 1,
            "title": best_title,
            "score": best_score,
            "tokens": best_doc.get("tokens", 0)
        },
        "all_documents": [
            {
                "rank": i + 1,
                "title": doc["metadata"].get("title", "N/A"),
                "score": round(doc["score"], 4),
                "tokens": doc.get("tokens", 0)
            }
            for i, doc in enumerate(retrieved)
        ],
        "context_tokens": context_tokens_used,
        "available_for_response": max_tokens_for_response,
        "rag_used": True,
        "query_type": "smoking",
        "reason": f"High confidence RAG ({best_score:.4f} ≥ {CONFIDENCE_THRESHOLD})",
        "similarity_interpretation": similarity_interpretation,
        "generation_temperature": temperature,
        "token_usage": {
            "prompt_tokens": result.get("total_prompt_tokens", 0),
            "answer_tokens": result.get("answer_tokens", 0),
            "context_tokens": result.get("context_tokens_used", 0),
            "total_tokens": result.get("total_tokens", 0),
            "context_window": CONTEXT_WINDOW
        }
    }

# =============================================================================
# INTERACTIVE RAG SYSTEM
# =============================================================================

print("\n" + "=" * 80)
print("RAG SYSTEM READY FOR QUERIES!")
print("=" * 80)
print("\nSystem Architecture:")
print("   Input: Rag_daridja_data_ar_only.json (token-based chunked)")
print("   Embedding: paraphrase-multilingual-MiniLM → 384-dim vectors")
print("   Search: FAISS cosine similarity (normalized inner product)")
print("   Generation: Fine-tuned Atlas model (using", model_type + ")")
print("\nKey Improvements:")
print("   Pre-RAG Intent Classification: Hybrid rule-based + LLM few-shot")
print("   Threshold: 0.65 for RAG utilization (strong match or better)")
print("   RAG Answer Tokens: 100-150 tokens for detailed smoking answers")
print("   Anti-Hallucination: Chain-of-verification prompts with hidden reasoning")
print("   Generation Params: temp=0.2-0.4 (RAG), repetition_penalty=1.2")
print("\nIntent Handling:")
print("   Greetings: Detected pre-RAG, clean response (no smoking hallucination)")
print("   Off-topic: Polite rejection, no elaboration")
print("   Insults: Respectful deflection, redirect to smoking questions")
print("   Offensive Detection: Via 'كلام قبيح أو سب' doc with score ≥ 0.7")
print("\nToken Budget Enforcement:")
print(f"   Context Window: {CONTEXT_WINDOW} tokens (hard limit)")
print(f"   Exact token counting for system prompt + context + question + response")
print(f"   Automatic truncation to fit within {CONTEXT_WINDOW} tokens")
print("=" * 80 + "\n")

def run_interactive_session():
    """
    Interactive RAG chatbot with complete intent classification and anti-hallucination.
    """
    print("Interactive RAG System - Algerian Darija Smoking Cessation Assistant")
    print("   Type 'exit' or 'quit' to stop\n")
    
    query_count = 0
    
    while True:
        try:
            user_input = input("\nYour question (Arabic/Darija): ").strip()
            
            if user_input.lower() in ["exit", "quit", "خروج"]:
                print(f"\n✓ Session ended. Total queries: {query_count}")
                print("Goodbye / بسلامة")
                break
            
            if not user_input:
                continue
            
            query_count += 1
            print(f"\n{'=' * 80}")
            print(f"QUERY #{query_count}")
            print('=' * 80)
            
            result = process_query(user_input)
            
            # Check if special handling (greeting, off-topic, insult)
            if result.get("special_handling", False):
                print(f"\nINTENT CLASSIFICATION (Pre-RAG)")
                print("-" * 80)
                print(f"  Detected Type: {result['query_type'].upper()}")
                print(f"  Reason: {result['reason']}")
                print(f"  Retrieval: SKIPPED (handled before RAG)")
            else:
                print("\nRETRIEVED DOCUMENTS (Top 5, ranked by cosine similarity)")
                print("-" * 80)
                for doc in result["all_documents"]:
                    marker = " ✓ SELECTED" if doc["rank"] == 1 else ""
                    print(f"  [{doc['rank']}] {doc['title']}{marker}")
                    print(f"      Score: {doc['score']:.4f} | Tokens: {doc['tokens']}")
                
                if result["selected_document"]:
                    print(f"\nCOSINE SIMILARITY ANALYSIS")
                    print("-" * 80)
                    print(f"  Top Document: {result['selected_document']['title']}")
                    print(f"  Cosine Similarity Score: {result['confidence']:.4f}")
                    
                    if "similarity_interpretation" in result and result["similarity_interpretation"]:
                        interp = result["similarity_interpretation"]
                        print(f"  Category: {interp['category']}")
                        print(f"  Vector Angle: {interp['angle_description']}")
                        print(f"  Confidence Level: {interp['confidence_level']}")
                        print(f"  Recommended Action: {interp['recommended_action']}")
                    
                    print(f"\nDECISION & GENERATION SETTINGS")
                    print("-" * 80)
                    print(f"  Threshold: 0.65 (STRONG MATCH or better)")
                    print(f"  Decision: {'RAG (context used)' if result['rag_used'] else 'Model Knowledge (no context)'}")
                    print(f"  Reason: {result['reason']}")
                    
                    if result["rag_used"]:
                        temp = result.get('generation_temperature', 0.25)
                        print(f"  Generation: Temperature={temp}, Max Tokens={result['available_for_response']}")
                        print(f"  Total Context Tokens Used: {result['context_tokens']} tokens")
                    else:
                        temp = result.get('generation_temperature', 0.5)
                        tokens = result.get('model_knowledge_tokens', 80)
                        print(f"  Generation: Temperature={temp}, Max Tokens={tokens}")
                        print(f"  Reasoning: Cosine similarity below threshold, using model knowledge")
            
            # Show token usage
            if "token_usage" in result:
                print(f"\nTOKEN USAGE (Context Window: {CONTEXT_WINDOW} tokens)")
                print("-" * 80)
                usage = result["token_usage"]
                print(f"  Prompt Tokens: {usage.get('prompt_tokens', 'N/A')}")
                if usage.get('context_tokens', 0) > 0:
                    print(f"  Context Tokens: {usage['context_tokens']}")
                print(f"  Answer Tokens: {usage.get('answer_tokens', 'N/A')}")
                print(f"  Total Tokens: {usage.get('total_tokens', 'N/A')}")
                
                total = usage.get('total_tokens', 0)
                if total > CONTEXT_WINDOW:
                    print(f"  ⚠ WARNING: Exceeded context window by {total - CONTEXT_WINDOW} tokens!")
                else:
                    remaining = CONTEXT_WINDOW - total
                    print(f"  Remaining: {remaining} tokens ({remaining/CONTEXT_WINDOW:.1%} free)")
            
            print("\nGENERATED ANSWER")
            print("-" * 80)
            print(result["answer"])
            print("-" * 80)
        
        except KeyboardInterrupt:
            print(f"\n\n⚠ Session terminated by user (Ctrl+C)")
            print(f"Total queries processed: {query_count}")
            break
        except Exception as e:
            print(f"Error processing query: {str(e)}")
            print("Please try again or type 'exit' to quit.")
            continue

In [None]:
# session run using mobile atlas
if __name__ == "__main__":
    run_interactive_session()

# NGROK SETUP

In [4]:
!pip install fastapi uvicorn pyngrok -q

In [None]:
from pyngrok import ngrok

ngrok.set_auth_token("your_ngrok_auth_token_here")

public_url = ngrok.connect(8000)

print(public_url)


NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"


In [None]:
import subprocess
import time
import requests
from pathlib import Path

print("\n" + "="*80)
print("STARTING PRODUCTION API + NGROK (EXACT RAG LOGIC)")
print("="*80)

# Save the corrected API code with actual path values injected
api_code = '''"""
Production RAG API for Kaggle + ngrok
Uses EXACT logic from your fine-tuned Atlas model with all improvements
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import numpy as np
from sentence_transformers import SentenceTransformer
import pickle
import json
from pathlib import Path
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
import sys

# Configuration - Match your RAG code exactly
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
CONTEXT_WINDOW = 1024
CONFIDENCE_THRESHOLD = 0.65  # UPDATED from 0.78
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_CONTEXT_LENGTH = 512

# Paths injected from notebook variables
INDEX_PATH = "{index_path}"
CHUNKS_PATH = "{chunks_path}"
ATLAS_MERGED_PATH = "{atlas_merged_path}"

app = FastAPI(title="Atlas RAG API", version="2.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state
embed_model = None
atlas_model = None
atlas_tokenizer = None
index = None
chunked_documents = []

@app.on_event("startup")
async def startup_event():
    global embed_model, atlas_model, atlas_tokenizer, index, chunked_documents
    
    print("\\n" + "="*80, flush=True)
    print("INITIALIZING RAG SYSTEM", flush=True)
    print("="*80, flush=True)
    
    # Load embedding model
    print("\\n[1/3] Loading embedding model...", flush=True)
    try:
        embed_model = SentenceTransformer(EMBED_MODEL_NAME)
        print("✓ Embedding model ready", flush=True)
    except Exception as e:
        print(f"✗ Error: {{e}}", flush=True)
    
    # Load FAISS index
    print("[2/3] Loading FAISS index...", flush=True)
    try:
        index_path = INDEX_PATH
        chunks_path = CHUNKS_PATH
        
        if Path(index_path).exists() and Path(chunks_path).exists():
            index = faiss.read_index(index_path)
            with open(chunks_path, "rb") as f:
                chunked_documents = pickle.load(f)
            print(f"✓ FAISS index loaded: {{index.ntotal}} vectors", flush=True)
            print(f"✓ Chunks loaded: {{len(chunked_documents)}} documents", flush=True)
        else:
            print("⚠ FAISS files not found", flush=True)
    except Exception as e:
        print(f"✗ Error: {{e}}", flush=True)
    
    # Load Atlas model (FP16)
    print("[3/3] Loading Atlas model (FP16)...", flush=True)
    try:
        if ATLAS_MERGED_PATH and ATLAS_MERGED_PATH != "":
            print(f"Loading from: {{ATLAS_MERGED_PATH}}", flush=True)
            atlas_model = AutoModelForCausalLM.from_pretrained(
                str(ATLAS_MERGED_PATH),
                device_map="auto",
                trust_remote_code=True
            )
            
            atlas_tokenizer = AutoTokenizer.from_pretrained(
                str(ATLAS_MERGED_PATH),
                trust_remote_code=True
            )
            
            if atlas_tokenizer.pad_token is None:
                atlas_tokenizer.pad_token = atlas_tokenizer.eos_token
            
            print("✓ Atlas model loaded (FP16)", flush=True)
        else:
            print(f"⚠ Model path not provided: {{ATLAS_MERGED_PATH}}", flush=True)
    except Exception as e:
        print(f"✗ Error loading model: {{e}}", flush=True)
        import traceback
        traceback.print_exc()
    
    sys.stdout.flush()

def count_tokens_exact(text, model_type="atlas"):
    """Count tokens exactly - match your RAG code"""
    if not text:
        return 0
    
    try:
        if atlas_tokenizer:
            return len(atlas_tokenizer.encode(text))
        
        import tiktoken
        try:
            encoder = tiktoken.get_encoding("cl100k_base")
            return len(encoder.encode(text))
        except:
            pass
    except Exception:
        pass
    
    words = len(text.split())
    return int(words / 0.75)

def embed_query_simple(text):
    """Simple embedding function - match your RAG code"""
    try:
        embedding = embed_model.encode(text, convert_to_numpy=True)
        embedding = embedding / np.linalg.norm(embedding)
        return embedding.astype("float32")
    except Exception:
        return None

def retrieve_documents(question, top_k=5, max_tokens=None):
    """Retrieve documents with token budget - match your RAG code"""
    if index is None or not chunked_documents:
        return []
    
    if max_tokens is None:
        max_tokens = 500
    
    query_vec = embed_query_simple(question)
    if query_vec is None:
        return []
    
    try:
        scores, indices = index.search(np.array([query_vec]), min(top_k * 2, len(chunked_documents)))
    except:
        return []
    
    results = []
    total_tokens = 0
    
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0 or idx >= len(chunked_documents):
            continue
        
        chunk = chunked_documents[idx]
        tokens = chunk["metadata"].get("token_count", count_tokens_exact(chunk["text"]))
        
        if results and total_tokens + tokens > max_tokens:
            break
        
        results.append({
            "text": chunk["text"],
            "metadata": chunk["metadata"],
            "score": float(score),
            "tokens": tokens
        })
        total_tokens += tokens
        
        if len(results) >= top_k:
            break
    
    return sorted(results, key=lambda x: x["score"], reverse=True)

def interpret_cosine_similarity(score):
    """UPDATED interpretation with new thresholds - match your RAG code exactly"""
    angle = np.degrees(np.arccos(min(max(score, 0.0), 1.0)))
    
    if score >= 0.65:
        category = "VERY STRONG MATCH"
        confidence = "Very High"
        angle_desc = f"{angle:.1f}° (almost identical direction)"
        action = "RAG with high confidence"
    elif score >= 0.5:  # NEW THRESHOLD
        category = "STRONG MATCH"
        confidence = "High"
        angle_desc = f"{angle:.1f}° (very close)"
        action = "RAG recommended"
    elif score >= 0.45:
        category = "GOOD MATCH"
        confidence = "Medium-High"
        angle_desc = f"{angle:.1f}° (good similarity)"
        action = "RAG suitable"
    elif score >= 0.35:
        category = "MODERATE MATCH"
        confidence = "Medium"
        angle_desc = f"{angle:.1f}° (somewhat related)"
        action = "Consider RAG"
    elif score >= 0.3:
        category = "WEAK MATCH"
        confidence = "Low"
        angle_desc = f"{angle:.1f}° (weak relation)"
        action = "Model knowledge preferred"
    elif score >= 0.25:
        category = "VERY WEAK MATCH"
        confidence = "Very Low"
        angle_desc = f"{angle:.1f}° (barely related)"
        action = "Model knowledge recommended"
    else:
        category = "POOR MATCH"
        confidence = "None"
        angle_desc = f"{angle:.1f}° (unrelated)"
        action = "Use model knowledge only"
    
    return {
        "category": category,
        "confidence_level": confidence,
        "angle_degrees": round(angle, 1),
        "angle_description": angle_desc,
        "recommended_action": action,
        "score": score
    }

def classify_intent_hybrid(question):
    """UPDATED with new examples and logic - match your RAG code"""
    question_lower = question.lower()
    
    # Fast rule-based filter
    greeting_keywords = [
        "سلام", "السلام", "مرحبا", "صباح", "مساء", 
        "شكرا", "شكراً", "أهلا", "اهلا", "واش راك",
        "واش حالك", "واش كيفك", "كيف راك", "كيفاش راك",
        "واش الحالة", "كيف حالك", "كيف الحال"
    ]
    
    words = question_lower.split()
    if len(words) <= 4:  # UPDATED from 3 to 4
        if any(kw in question_lower for kw in greeting_keywords):
            return "تحية", generate_special_response("تحية", question)
    
    # UPDATED few-shot examples with new patterns
    few_shot_examples = [
        # GREETINGS (Adding slang)
        ("واش الحالة", "تحية"),
        ("صحة خويا", "تحية"),
        ("واش يا البوت", "تحية"),
        ("واش راك داير فيها", "تحية"),
        
        # SMOKING (Adding frustration/emotions)
        ("كرهت حياتي ياخو", "سؤال تدخين"),
        ("غلبتني السيجارة", "سؤال تدخين"),
        ("راني فشلان وتعبان", "سؤال تدخين"),
        ("حاب نتهنى من هاد السم", "سؤال تدخين"),
        
        # ACTUAL INSULTS (Be very specific)
        ("أنت حمار", "سب"),
        ("تفو عليك", "سب"),
        ("يا ولد الحرام", "سب"),
        
        # ACTUAL OFF-TOPIC
        ("كيفاش نطيب اللحم؟", "غير ذي صلة"),
        ("شكون ربح الماتش؟", "غير ذي صلة"),
        ("واش رايك في ميسي؟", "غير ذي صلة")
    ]
    
    examples_text = "\\n".join([f"الرسالة: {q}\\nالتصنيف: {c}" for q, c in few_shot_examples])
    
    classifier_prompt = f"""تصنف الرسالة في واحد من: سؤال تدخين / تحية / سب / غير ذي صلة

أمثلة:
{examples_text}

الرسالة: {question}
التصنيف:"""
    
    try:
        inputs = atlas_tokenizer(
            classifier_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(DEVICE)
        
        with torch.no_grad():
            outputs = atlas_model.generate(
                **inputs,
                max_new_tokens=15,
                temperature=0.1,
                top_p=0.5,
                do_sample=True,
                pad_token_id=atlas_tokenizer.eos_token_id
            )
        
        intent_raw = atlas_tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "التصنيف:" in intent_raw:
            intent_raw = intent_raw.split("التصنيف:")[-1].strip()
        
        scores = {
            "تحية": intent_raw.count("تحية"),
            "سب": intent_raw.count("سب"),
            "غير ذي صلة": intent_raw.count("غير ذي صلة"),
            "سؤال تدخين": intent_raw.count("سؤال تدخين")
        }
        
        top_intent = max(scores, key=scores.get)
        
        if top_intent == "تحية" or scores["تحية"] > 0:
            return "greeting", generate_special_response(top_intent, question)
        elif top_intent == "سب" or scores["سب"] > 0:
            return "insult", "معليش، نحترم الجميع هنا. عندك سؤال حول التدخين؟"
        elif top_intent == "غير ذي صلة" or scores["غير ذي صلة"] > 0:
            return "off_topic", "خاطيني، أنا نجاوب غير على أسئلة التدخين والإقلاع عنه."
        
        return "smoking", None
    
    except Exception as e:
        print(f"Intent classification error: {e}", flush=True)
        return "smoking", None

def select_greeting_by_similarity(user_question):
    """Select greeting by similarity - match your RAG code"""
    greeting_map = [
        {"anchors": ["السلام عليكم", "سلام"], "response": "وعليكم السلام ورحمة الله! مرحبا بيك، قولي واش هو سؤالك على التدخين؟"},
        {"anchors": ["أهلا", "مرحبا", "واش راك"], "response": "أهلا بيك! واش راك؟ كيفاش نقدر نعاونك اليوم في موضوع التدخين؟"},
        {"anchors": ["صباح الخير", "كي صبحت"], "response": "صباح النور والسرور! واش راك؟ كاش ما نقدر نعاونك في موضوع التدخين اليوم؟"},
        {"anchors": ["مساء الخير", "كي عشيت"], "response": "مساء الخير والأنوار! واش أحوالك؟ راني هنا إذا سحقيت كاش نصيحة على التدخين"}
    ]
    
    anchor_phrases = [m["anchors"][0] for m in greeting_map]
    responses = [m["response"] for m in greeting_map]

    user_embedding = embed_query_simple(user_question)
    if user_embedding is None:
        return responses[1], 0.0
    
    greeting_embeddings = []
    for anchor in anchor_phrases:
        emb = embed_query_simple(anchor)
        if emb is not None:
            greeting_embeddings.append(emb)
        else:
            greeting_embeddings.append(np.zeros_like(user_embedding))
    
    greeting_embeddings = np.array(greeting_embeddings)
    similarities = np.dot(greeting_embeddings, user_embedding)
    
    best_idx = np.argmax(similarities)
    best_similarity = float(similarities[best_idx])
    selected_greeting = responses[best_idx]
    
    return selected_greeting, best_similarity

def generate_special_response(query_type_info, question):
    """Generate special response - match your RAG code"""
    if isinstance(query_type_info, str):
        q_type = query_type_info
        if query_type_info == "تحية":
            selected, similarity = select_greeting_by_similarity(question)
            return selected
        elif query_type_info == "سب":
            return "معليش، ما نقدرش نجاوب على هاد النوع من الكلام. تكلم باحترام ونقدر نعاونك."
    else:
        q_type = query_type_info.get("type")
        if query_type_info["type"] == "تحية":
            selected, similarity = select_greeting_by_similarity(question)
            return selected
        elif query_type_info["type"] == "سب":
            return "معليش، ما نقدرش نجاوب على هاد النوع من الكلام. تكلم باحترام ونقدر نعاونك."
    return None

def clean_model_output(answer):
    """
    Clean model output to remove any training markers like user:, assistant:, system:
    This fixes the issue where model echoes back its training format.
    """
    answer = answer.strip()
    
    # Pattern 1: Remove "user:" or "assistant:" prefixes
    answer = re.sub(r'^(user:|assistant:|system:)\\s*', '', answer, flags=re.IGNORECASE)
    
    # Pattern 2: If there's "A: " at the start (from our prompt), remove it
    if answer.startswith('A: '):
        answer = answer[3:].strip()
    
    # Pattern 3: Remove any conversation format that leaked through
    markers = ['\\nuser:', '\\nassistant:', '\\nsystem:', '\\nA:', '\\nالسياق:', '\\nالسؤال:']
    for marker in markers:
        if marker in answer:
            answer = answer.split(marker)[0].strip()
    
    # Pattern 4: If answer still has system/user/assistant anywhere, extract after last marker
    if any(word in answer.lower() for word in ['user:', 'assistant:', 'system:']):
        parts = re.split(r'(user:|assistant:|system:|A:)', answer, flags=re.IGNORECASE)
        if len(parts) > 1:
            for part in reversed(parts):
                if part.strip() and not any(m in part.lower() for m in ['user:', 'assistant:', 'system:']):
                    answer = part.strip()
                    break
    
    return answer.strip()

def generate_answer(question, context=None, max_tokens_for_response=120, use_rag=True, temperature=None):
    """UPDATED generate_answer with anti-hallucination measures - match your RAG code exactly"""
    
    if use_rag and context:
        # UPDATED with anti-hallucination prompt
        system_prompt = (
            "أنت خبير جزائري مختص في التوعية ضد التدخين. "
            "تحدث بالدارجة الجزائرية البيضاء (فصيحة تقنياً). "
            "التزم بالحقائق العلمية فقط. ممنوع الدراما، ممنوع السياسة، وممنوع التحدث عن دول أخرى .جاوب مباشرة تقنياً. ابدأ الجواب بـ 'بناءً على المعلومات المتوفرة'"
            "خاطب المستخدم بصيغة المذكر دائماً إلا إذا ذكر عكس ذلك."
        )
        # UPDATED: Anti-hallucination with hidden reasoning
        prompt = f"""system: {system_prompt}

{context if context else "ملاحظة: لا يوجد سياق خارجي، جاوب باختصار من القواعد العامة للإقلاع عن التدخين."}

التعليمات الإجبارية:
1. استخرج النصائح من السياق (Context) إذا كان متوفراً.
2. إذا كان السؤال بعيداً عن التدخين، قل: "أنا هنا للمساعدة في الإقلاع عن التدخين فقط".
3. ممنوع نهائياً ذكر أي جمل درامية غير واقعية.
4. الجواب يكون في شكل نقاط (Bullet points) ليكون واضحاً.
5. لا تزد عن 80 كلمة لتجنب انقطاع النص.

user: السؤال: {question}

A: """
        
        if temperature is None:
            temperature = 0.2  # UPDATED from 0.35
    else:
        system_prompt = (
            "أنت مساعد جزائري مختص في التدخين والإقلاع عنه، تهدر بالدارجة الجزائرية. "
            "جاوب بإجابات قصيرة ومباشرة وعملية بلا خطبة. ممنوع تمد نصائح طبية من راسك. ممنوع الفلسفة."
        )
        
        prompt = f"""system: {system_prompt}

user: {question}

A: """
        
        if temperature is None:
            temperature = 0.4  # UPDATED from 0.7
    
    try:
        inputs = atlas_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_CONTEXT_LENGTH
        ).to(DEVICE)
        
        prompt_length = inputs["input_ids"].shape[1]
        
        with torch.no_grad():
            outputs = atlas_model.generate(
                **inputs,
                max_new_tokens=max_tokens_for_response,
                temperature=temperature,
                top_p=0.4,              # UPDATED from 0.8
                top_k=30,               # UPDATED from 40
                do_sample=True,
                repetition_penalty=1.2,  # UPDATED from 1.15
                pad_token_id=atlas_tokenizer.eos_token_id,
                eos_token_id=atlas_tokenizer.eos_token_id
            )
        
        # Extract only the new tokens generated (not the prompt)
        generated_tokens = outputs[0][prompt_length:]
        raw_answer = atlas_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        
        # Clean the output to remove any training format artifacts
        answer = clean_model_output(raw_answer)
        
        prompt_tokens = count_tokens_exact(prompt)
        answer_tokens = count_tokens_exact(answer)
        
        return {
            "answer": answer,
            "total_prompt_tokens": prompt_tokens,
            "answer_tokens": answer_tokens,
            "context_tokens_used": count_tokens_exact(context) if context else 0,
            "total_tokens": prompt_tokens + answer_tokens
        }
    
    except Exception as e:
        print(f"Error generating answer: {str(e)}", flush=True)
        import traceback
        traceback.print_exc()
        return {
            "answer": "معليش، وقعت مشكلة في الجواب. حاول تاني بعد شوية.",
            "total_prompt_tokens": 0,
            "answer_tokens": 0,
            "context_tokens_used": 0,
            "total_tokens": 0
        }

def process_query(question):
    """UPDATED complete RAG pipeline with new thresholds - match your RAG code exactly"""
    
    # STEP 1: Pre-RAG Intent Classification
    intent, canned_response = classify_intent_hybrid(question)
    
    if canned_response:
        return {
            "answer": canned_response,
            "confidence": 1.0,
            "selected_document": None,
            "all_documents": [],
            "context_tokens": 0,
            "rag_used": False,
            "query_type": intent,
            "reason": f"Intent classification: {intent} (handled pre-RAG)",
            "similarity_interpretation": None,
            "special_handling": True,
            "token_usage": {
                "prompt_tokens": count_tokens_exact(f"system: \\nuser: {question}\\nassistant: "),
                "answer_tokens": count_tokens_exact(canned_response),
                "total_tokens": count_tokens_exact(f"system: \\nuser: {question}\\nassistant: {canned_response}"),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    # STEP 2: Retrieve documents
    retrieved = retrieve_documents(question, top_k=5, max_tokens=500)
    
    if not retrieved:
        result = generate_answer(question, context=None, max_tokens_for_response=80, use_rag=False)
        return {
            "answer": result["answer"],
            "confidence": 0.0,
            "selected_document": None,
            "all_documents": [],
            "context_tokens": 0,
            "rag_used": False,
            "query_type": "smoking",
            "reason": "No relevant documents found",
            "similarity_interpretation": interpret_cosine_similarity(0.0),
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    best_doc = retrieved[0]
    best_score = best_doc["score"]
    best_title = best_doc["metadata"].get("title", "N/A")
    context = best_doc["text"]
    
    similarity_interpretation = interpret_cosine_similarity(best_score)
    
    # STEP 3: UPDATED dynamic token allocation with new thresholds
    if best_score >= 0.77:  # NEW THRESHOLD
        max_tokens_for_response = 250
        temperature = 0.32
    elif best_score >= CONFIDENCE_THRESHOLD:  # 0.65
        max_tokens_for_response = 180
        temperature = 0.4
    elif best_score >= 0.45:  # NEW THRESHOLD
        max_tokens_for_response = 150
        temperature = 0.45
        # Concatenate all docs
        context_parts = []
        for doc in retrieved:
            context_parts.append(f"[{doc['metadata']['title']}]\\n{doc['text']}")
        context = "\\n---\\n".join(context_parts)
        
        context_tokens_used = sum(doc["tokens"] for doc in retrieved)
        
        result = generate_answer(question, context, max_tokens_for_response=max_tokens_for_response, 
                               use_rag=True, temperature=temperature)
        
        return {
            "answer": result["answer"],
            "confidence": round(best_score, 4),
            "selected_document": {
                "rank": 1,
                "title": best_title,
                "score": best_score,
                "tokens": best_doc.get("tokens", 0)
            },
            "all_documents": [
                {
                    "rank": i + 1,
                    "title": doc["metadata"].get("title", "N/A"),
                    "score": round(doc["score"], 4),
                    "tokens": doc.get("tokens", 0)
                }
                for i, doc in enumerate(retrieved)
            ],
            "context_tokens": context_tokens_used,
            "available_for_response": max_tokens_for_response,
            "rag_used": True,
            "query_type": "smoking",
            "reason": f"Good match RAG ({best_score:.4f} ≥ 0.45)",
            "similarity_interpretation": similarity_interpretation,
            "generation_temperature": temperature,
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "context_tokens": result.get("context_tokens_used", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    else:
        # Confidence too low, use model knowledge with UPDATED parameters
        if best_score >= 0.3:
            temperature = 0.45
            max_tokens = 80
        else:
            temperature = 0.5
            max_tokens = 70
        
        result = generate_answer(question, context=None, max_tokens_for_response=max_tokens, 
                               use_rag=False, temperature=temperature)
        
        return {
            "answer": result["answer"],
            "confidence": round(best_score, 4),
            "selected_document": {
                "rank": 1,
                "title": best_title,
                "score": best_score,
                "tokens": best_doc.get("tokens", 0)
            },
            "all_documents": [
                {
                    "rank": i + 1,
                    "title": doc["metadata"].get("title", "N/A"),
                    "score": round(doc["score"], 4),
                    "tokens": doc.get("tokens", 0)
                }
                for i, doc in enumerate(retrieved)
            ],
            "context_tokens": 0,
            "rag_used": False,
            "query_type": "smoking",
            "reason": f"Cosine similarity {best_score:.4f} < {CONFIDENCE_THRESHOLD} (threshold)",
            "similarity_interpretation": similarity_interpretation,
            "generation_temperature": temperature,
            "model_knowledge_tokens": max_tokens,
            "token_usage": {
                "prompt_tokens": result.get("total_prompt_tokens", 0),
                "answer_tokens": result.get("answer_tokens", 0),
                "total_tokens": result.get("total_tokens", 0),
                "context_window": CONTEXT_WINDOW
            }
        }
    
    # High confidence RAG path (>= 0.77)
    result = generate_answer(question, context, max_tokens_for_response=max_tokens_for_response, 
                           use_rag=True, temperature=temperature)
    
    return {
        "answer": result["answer"],
        "confidence": round(best_score, 4),
        "selected_document": {
            "rank": 1,
            "title": best_title,
            "score": best_score,
            "tokens": best_doc.get("tokens", 0)
        },
        "all_documents": [
            {
                "rank": i + 1,
                "title": doc["metadata"].get("title", "N/A"),
                "score": round(doc["score"], 4),
                "tokens": doc.get("tokens", 0)
            }
            for i, doc in enumerate(retrieved)
        ],
        "context_tokens": result.get("context_tokens_used", 0),
        "available_for_response": max_tokens_for_response,
        "rag_used": True,
        "query_type": "smoking",
        "reason": f"High confidence RAG ({best_score:.4f} ≥ {CONFIDENCE_THRESHOLD})",
        "similarity_interpretation": similarity_interpretation,
        "generation_temperature": temperature,
        "token_usage": {
            "prompt_tokens": result.get("total_prompt_tokens", 0),
            "answer_tokens": result.get("answer_tokens", 0),
            "context_tokens": result.get("context_tokens_used", 0),
            "total_tokens": result.get("total_tokens", 0),
            "context_window": CONTEXT_WINDOW
        }
    }

class QueryRequest(BaseModel):
    query: str

@app.get("/")
async def root():
    return {"status": "online", "service": "Atlas RAG API v2.0"}

@app.get("/health")
async def health():
    return {
        "status": "healthy", 
        "documents": len(chunked_documents), 
        "model_loaded": atlas_model is not None,
        "tokenizer_loaded": atlas_tokenizer is not None,
        "index_loaded": index is not None,
        "model": "Atlas 2B FP16",
        "threshold": CONFIDENCE_THRESHOLD,
        "context_window": CONTEXT_WINDOW
    }

@app.post("/query")
async def query(request: QueryRequest):
    try:
        return process_query(request.query)
    except Exception as e:
        import traceback
        print(f"Error in /query endpoint: {str(e)}", flush=True)
        traceback.print_exc()
        return {"error": str(e), "answer": "معليش، صار خلل. حاول مرة أخرى."}

@app.get("/info")
async def info():
    return {
        "model": "Atlas 2B FP16",
        "documents": len(chunked_documents),
        "threshold": CONFIDENCE_THRESHOLD,
        "context_window": CONTEXT_WINDOW,
        "version": "2.0.0 (All improvements included)",
        "model_loaded": atlas_model is not None,
        "components": {
            "embedding": embed_model is not None,
            "faiss": index is not None,
            "atlas_model": atlas_model is not None,
            "tokenizer": atlas_tokenizer is not None
        }
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
'''

# Inject the actual path values using string replacement
api_code_with_paths = api_code.replace("{index_path}", INDEX_PATH)
api_code_with_paths = api_code_with_paths.replace("{chunks_path}", CHUNKS_PATH)
api_code_with_paths = api_code_with_paths.replace("{atlas_merged_path}", ATLAS_MERGED_PATH)

# Write API code
with open("production_api.py", "w") as f:
    f.write(api_code_with_paths)

print("✓ API code created (EXACT RAG LOGIC WITH ALL IMPROVEMENTS)")

# Start FastAPI WITHOUT piping stdout
print("\n[1/2] Starting FastAPI server...")
api_process = subprocess.Popen(
    ["python", "-u", "production_api.py"],  # -u for unbuffered output
    stdout=None,  # Don't pipe - let output show
    stderr=None
)

time.sleep(15)  # Give more time for model loading

# Expose with ngrok
print("[2/2] Exposing with ngrok...")
try:
    public_url = ngrok.connect(8000)
    print(f"\n" + "="*80)
    print("✅ API IS LIVE WITH ALL IMPROVEMENTS!")
    print("="*80)
    print(f"\nAPI URL: {public_url}")
    print(f"Query: {public_url}/query")
    print(f"Health: {public_url}/health")
    print(f"Info: {public_url}/info")
    
    # Test health endpoint
    print("\n" + "="*80)
    print("VERIFYING SYSTEM STATUS")
    print("="*80)
    try:
        health_response = requests.get(f"{public_url}/health", timeout=10)
        health_data = health_response.json()
        print(f"✓ Health check successful:")
        print(f"  Documents loaded: {health_data.get('documents', 0)}")
        print(f"  Model loaded: {health_data.get('model_loaded', False)}")
        print(f"  Tokenizer loaded: {health_data.get('tokenizer_loaded', False)}")
        print(f"  Index loaded: {health_data.get('index_loaded', False)}")
        
        if not health_data.get('model_loaded'):
            print("\n⚠️ WARNING: Atlas model failed to load!")
        if health_data.get('documents', 0) == 0:
            print("\n⚠️ WARNING: No documents loaded!")
    except Exception as e:
        print(f"⚠️ Health check failed: {e}")
    
    print("\n" + "="*80)
    print("KEY IMPROVEMENTS ACTIVE:")
    print("="*80)
    print("✓ Threshold: 0.65 (was 0.78)")
    print("✓ Dynamic tokens: 70-250 based on confidence")
    print("✓ Anti-hallucination prompts with hidden reasoning")
    print("✓ Temperature: 0.2-0.5 (lower for better accuracy)")
    print("✓ Updated intent classification with 14 examples")
    print("✓ Greeting similarity matching")
    print("✓ Clean output (removes 'user:', 'assistant:' markers)")
    print("✓ New thresholds: 0.77, 0.65, 0.45, 0.3")
    print("="*80)
    
    print("\n" + "-"*80)
    print("Test with Postman or curl:")
    print(f"POST {public_url}/query")
    print('Body: {"query": "واش علاش مهم نقلع عن التدخين؟"}')
    print("-"*80)
    
    # Keep running
    print("\nAPI running (Ctrl+C to stop)...")
    try:
        while True:
            time.sleep(30)
    except KeyboardInterrupt:
        print("\nStopping...")
        ngrok.disconnect(public_url)
        api_process.terminate()
        print("✓ Stopped")

except Exception as e:
    print(f"✗ Error: {e}")
    api_process.terminate()


STARTING PRODUCTION API + NGROK (EXACT RAG LOGIC)
✓ API code created (EXACT RAG LOGIC WITH ALL IMPROVEMENTS)

[1/2] Starting FastAPI server...


2026-02-05 13:41:37.503092: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770298897.525344     887 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770298897.531841     887 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770298897.548058     887 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770298897.548085     887 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770298897.548089     887 computation_placer.cc:177] computation placer alr


INITIALIZING RAG SYSTEM

[1/3] Loading embedding model...
✓ Embedding model ready
[2/3] Loading FAISS index...
✓ FAISS index loaded: {index.ntotal} vectors
✓ Chunks loaded: {len(chunked_documents)} documents
[3/3] Loading Atlas model (FP16)...
Loading from: {ATLAS_MERGED_PATH}
[2/2] Exposing with ngrok...

✅ API IS LIVE WITH ALL IMPROVEMENTS!

API URL: NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"
Query: NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"/query
Health: NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"/health
Info: NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"/info

VERIFYING SYSTEM STATUS
⚠️ Health check failed: No connection adapters were found for 'NgrokTunnel: "https://nikolas-interfilar-stalagmitically.ngrok-free.dev" -> "http://localhost:8000"/health'

KEY IMPR

INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     129.45.69.222:0 - "GET /health HTTP/1.1" 200 OK
INFO:     129.45.69.222:0 - "POST /query HTTP/1.1" 200 OK
INFO:     129.45.69.222:0 - "POST /query HTTP/1.1" 200 OK
