In [None]:
import json
import faiss
import numpy as np
import re
import logging
import torch
from typing import List, Tuple, Dict, Optional
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langdetect import detect

# Configure root logging once for the notebook/script and expose a
# module-scoped logger used by every class below.
logging.basicConfig(level="INFO")
logger: logging.Logger = logging.getLogger(__name__)

class DataLoader:
    """Load Q&A pairs and free-form instructions from a JSON knowledge file."""

    def __init__(self, json_path: str):
        self.json_path = json_path
        self.qa_pairs: List[Tuple[str, str]] = []
        self.instructions: List[str] = []

    def load(self) -> Tuple[List[Tuple[str, str]], List[str]]:
        """Read ``self.json_path`` and parse it into Q&A pairs and instructions.

        Returns:
            ``(qa_pairs, instructions)`` -- the same lists stored on the instance.

        Raises:
            FileNotFoundError: the file does not exist (logged, then re-raised).
            json.JSONDecodeError: the file is not valid JSON (logged, then re-raised).
        """
        # Start from fresh lists so calling load() again does not append a
        # second copy of every entry.
        self.qa_pairs = []
        self.instructions = []
        try:
            with open(self.json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            logger.error(f"Error loading JSON data: {e}")
            raise
        self._parse_json(data)
        return self.qa_pairs, self.instructions

    def _add_pair(self, entry: dict) -> None:
        """Append ``(question, answer)`` from *entry* when both are non-empty.

        Accepts either ``Q``/``A`` or ``question``/``answer`` keys. A list
        answer is joined with single spaces; items are stringified first so
        non-string items do not raise.
        """
        question = entry.get("Q") or entry.get("question")
        answer = entry.get("A") or entry.get("answer")
        if isinstance(answer, list):
            answer = " ".join(str(part) for part in answer)
        if question and answer:
            self.qa_pairs.append((question, answer))

    def _parse_json(self, data: object) -> None:
        """Populate ``qa_pairs`` and ``instructions`` from decoded JSON.

        Each top-level category value may be:
          * a list whose items are Q&A dicts, lists of Q&A dicts, or any other
            value (stringified and stored as an instruction);
          * a dict whose values are Q&A dicts, lists (each item stringified as
            an instruction), or any other value (stringified as an instruction).
        Other category values are ignored. A top-level JSON list is treated as
        a single unnamed category.
        """
        categories = data.items() if isinstance(data, dict) else [("<root>", data)]
        for category, content in categories:
            try:
                if isinstance(content, list):
                    for entry in content:
                        if isinstance(entry, dict):
                            self._add_pair(entry)
                        elif isinstance(entry, list):
                            # Skip non-dict items instead of letting .get() raise
                            # and discard the rest of the category.
                            for sub_entry in entry:
                                if isinstance(sub_entry, dict):
                                    self._add_pair(sub_entry)
                        else:
                            self.instructions.append(str(entry))
                elif isinstance(content, dict):
                    for value in content.values():
                        if isinstance(value, dict):
                            self._add_pair(value)
                        elif isinstance(value, list):
                            # Keep instructions as strings; they are embedded later.
                            self.instructions.extend(str(item) for item in value)
                        else:
                            self.instructions.append(str(value))
            except Exception as e:
                logger.warning(f"Error processing category '{category}': {e}")

        logger.info(f"Loaded {len(self.qa_pairs)} Q&A pairs and {len(self.instructions)} instructions.")


class EmbeddingPipeline:
    """Embed Q&A questions and instructions and search them with a FAISS index."""

    def __init__(self, qa_pairs: List[Tuple[str, str]], instructions: List[str]):
        self.qa_pairs = qa_pairs
        self.instructions = instructions
        self.embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        self.index = None
        self.all_data: List[Tuple[str, str]] = []

    def build_index(self) -> None:
        """(Re)build a flat L2 FAISS index over questions, then instructions.

        Row ``i`` of the index corresponds to ``self.all_data[i]``. Instructions
        are stored as ``(text, text)`` so retrieval returns the instruction
        itself. With no data at all, ``self.index`` is left as ``None``.
        """
        texts = [question for question, _ in self.qa_pairs] + list(self.instructions)
        self.all_data = self.qa_pairs + [(inst, inst) for inst in self.instructions]
        if not texts:
            self.index = None
            logger.warning("No Q&A pairs or instructions to index; retrieval will return nothing.")
            return
        # A single encode call keeps rows aligned with all_data and avoids the
        # 1-D empty array encode([]) returns when either group is empty.
        # FAISS requires C-contiguous float32 input.
        embeddings = np.ascontiguousarray(self.embedder.encode(texts), dtype=np.float32)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)
        logger.info(f"FAISS index built with {len(self.all_data)} entries.")

    def retrieve(self, query: str, top_k: int = 3, threshold: float = 0.7) -> List[str]:
        """Return the answers of up to *top_k* entries within *threshold* of *query*.

        ``IndexFlatL2`` reports squared L2 distances, and embeddings are not
        normalized here, so *threshold* is in squared-L2 units of the raw
        embeddings. Returns ``[]`` when the index is missing/empty, when
        ``top_k <= 0``, or if encoding/search fails (the error is logged).
        """
        if self.index is None or self.index.ntotal == 0 or top_k <= 0:
            return []
        try:
            query_embedding = np.ascontiguousarray(self.embedder.encode([query]), dtype=np.float32)
            distances, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
        except Exception as e:
            logger.error(f"Error retrieving answer: {e}")
            return []
        # FAISS pads missing results with label -1; never let that index
        # all_data from the end.
        return [
            self.all_data[idx][1]
            for dist, idx in zip(distances[0], indices[0])
            if idx >= 0 and dist < threshold
        ]


class TextGenerator:
    """Wrapper around a 4-bit quantized Aya Expanse 8B text-generation pipeline."""

    def __init__(self):
        # Local import: only needed here, and keeps the module import list unchanged.
        from transformers import BitsAndBytesConfig

        model_id = "CohereForAI/aya-expanse-8b"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, timeout=600)
        # Passing load_in_4bit directly to from_pretrained is deprecated in
        # favour of an explicit quantization config.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

    def generate(self, prompt: str, min_length: int, max_length: int) -> str:
        """Sample a completion for *prompt*.

        Args:
            prompt: Full prompt text.
            min_length: Minimum number of newly generated tokens.
            max_length: Maximum number of newly generated tokens.

        Returns:
            The completion only (prompt excluded), stripped; ``""`` on any error
            (the error is logged).

        Note:
            Hugging Face's ``min_length``/``max_length`` count the prompt tokens
            too, so a long prompt would consume (or exceed) the whole budget.
            They are therefore mapped to ``min_new_tokens``/``max_new_tokens``.
        """
        try:
            outputs = self.pipeline(
                prompt,
                max_new_tokens=max_length,
                # Never ask for a minimum above the maximum.
                min_new_tokens=min(min_length, max_length),
                do_sample=True,
                top_p=0.85,
                temperature=0.7,
                eos_token_id=self.tokenizer.eos_token_id,
                # Return only the continuation; str.replace(prompt, "") breaks
                # whenever decoding does not reproduce the prompt byte-for-byte.
                return_full_text=False,
            )
            return outputs[0]["generated_text"].strip()
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return ""


class ChatbotPipeline:
    """Retrieval-first chatbot that falls back to LLM generation.

    Per-user history is kept in ``user_conversations`` (user_id -> list of
    ``"User: ..."`` / ``"Chatbot: ..."`` strings). It is held in memory and
    never trimmed.
    """

    # Rough per-language multipliers for the generation length budget, keyed
    # by language code without region suffix (langdetect returns e.g. "zh-cn").
    _LANGUAGE_FACTORS = {
        "en": 1.0, "ar": 0.8, "fr": 1.1, "zh": 0.5, "es": 1.05,
        "de": 1.0, "ru": 1.1, "it": 1.0, "pt": 1.05, "ja": 0.6, "ko": 0.7, "tr": 0.9
    }

    def __init__(self, json_path: str):
        data_loader = DataLoader(json_path)
        qa_pairs, instructions = data_loader.load()

        self.user_conversations: Dict[str, List[str]] = {}
        self.embedding = EmbeddingPipeline(qa_pairs, instructions)
        self.embedding.build_index()
        self.generator = TextGenerator()

    def get_dynamic_word_range(self, user_input: str, retrieved_answers: List[str]) -> Tuple[int, int]:
        """Return a ``(min, max)`` generation length budget for this turn.

        The budget grows with the word count of the input plus all retrieved
        answers and is scaled by a per-language factor. It never drops below
        ``(150, 500)``.
        """
        try:
            language = detect(user_input)
        except Exception:  # langdetect raises on empty/featureless input
            language = "en"
        # Drop region suffixes ("zh-cn" -> "zh") so the factor table matches.
        language = language.split("-")[0]

        total_words = sum(len(ans.split()) for ans in retrieved_answers) + len(user_input.split())
        factor = self._LANGUAGE_FACTORS.get(language, 1.0)

        base_max_len = int((total_words + 30) * factor)
        return max(150, base_max_len - 100), max(500, base_max_len + 200)

    def format_response(self, text: str, width: int = 90) -> str:
        """Strip chat special tokens and word-wrap each line to *width* chars.

        Existing line breaks are preserved (so list items stay on their own
        lines); blank lines are dropped. A single word longer than *width* is
        kept intact on its own line.
        """
        text = re.sub(r"<\|im_end\|>|<\|endoftext\|>", "", text).strip()
        lines: List[str] = []
        for paragraph in text.splitlines():
            line = ""
            for word in paragraph.split():
                # Flush before overflowing, but never emit an empty line when
                # the first word alone exceeds the width.
                if line and len(line) + len(word) > width:
                    lines.append(line.strip())
                    line = ""
                line += f"{word} "
            if line:
                lines.append(line.strip())
        return "\n".join(lines)

    def add_bullets(self, text: str) -> str:
        """Prefix numbered (``1.``) or dash-led lines with a ``- `` bullet.

        Lines that already start with ``- `` are normalized rather than
        given a second dash.
        """
        result = []
        for line in text.split("\n"):
            if line.startswith("- "):
                result.append(f"- {line[2:].strip()}")
            elif re.match(r"^\d+\.|^-", line):
                result.append(f"- {line.strip()}")
            else:
                result.append(line)
        return "\n".join(result)

    def clear_conclusion(self, text: str) -> str:
        """Remove a "Conclusion:"/"Summary:"/"Final Thoughts:" label from the last line.

        Trailing whitespace is trimmed first so the label is found even when
        the text ends with newlines. If the last line held only the label, it
        is dropped.
        """
        lines = text.rstrip().split("\n")
        last = lines[-1]
        for phrase in ("Conclusion:", "Summary:", "Final Thoughts:"):
            if last.startswith(phrase):
                last = last[len(phrase):].strip()
        if not last and len(lines) > 1:
            lines.pop()
        else:
            lines[-1] = last
        return "\n".join(lines)

    @staticmethod
    def _build_prompt(user_input: str, context: str, prior_turns: List[str]) -> str:
        """Assemble the LLM prompt (no stray source-code indentation)."""
        parts = [
            "You are an AI assistant specialized in answering questions.",
            "- Use provided context when available.",
            "- Ensure responses are **detailed but concise**.",
            "- Use **bullet points** when listing details.",
            "- Do not repeat information unnecessarily.",
            "- Maintain a **conversational tone**.",
            "",
            f"Context: {context}" if context else "No specific context available.",
            "",
            "Conversation history:",
            "\n".join(prior_turns) if prior_turns else "(none)",
            "",
            f"User question: {user_input}",
            "Answer:",
        ]
        return "\n".join(parts)

    def generate_response(self, user_id: str, user_input: str) -> str:
        """Answer *user_input* for *user_id* and record both turns in history.

        A short (< 50 words) top retrieved answer is returned verbatim.
        Otherwise the LLM is prompted with up to two retrieved answers as
        context and the user's last four prior turns. A trailing period is
        added when the reply does not end in ``.``, ``!`` or ``?``.
        """
        conversation = self.user_conversations.setdefault(user_id, [])
        # Snapshot prior turns before recording the new message, so the
        # question is not duplicated in "Conversation history" and
        # "User question".
        prior_turns = conversation[-4:]
        conversation.append(f"User: {user_input}")
        retrieved_answers = self.embedding.retrieve(user_input)

        if retrieved_answers and len(retrieved_answers[0].split()) < 50:
            response = retrieved_answers[0]
        else:
            context = " ".join(retrieved_answers[:2])
            min_tokens, max_tokens = self.get_dynamic_word_range(user_input, retrieved_answers)
            prompt = self._build_prompt(user_input, context, prior_turns)

            response = self.generator.generate(prompt, min_tokens, max_tokens)
            # Strip the label while it still starts the final paragraph; after
            # wrapping, the last line would only be a fragment of it.
            response = self.clear_conclusion(response)
            response = self.format_response(response)
            response = self.add_bullets(response)

        if response and response[-1] not in ".!?":
            response += "."

        conversation.append(f"Chatbot: {response}")
        return response

    def start_chat(self, user_id: str = "user1") -> None:
        """Run an interactive console chat until 'exit'/'quit', EOF or Ctrl-C."""
        print("Start chatting with the AI! Type 'exit' or 'quit' to stop.\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nChat ended.")
                break
            if not user_input:
                continue  # don't spend a generation on an empty line
            if user_input.lower() in ("exit", "quit"):
                print("Chat ended.")
                break
            response = self.generate_response(user_id=user_id, user_input=user_input)
            print(f"Bot: {response}\n")

