# In [None]:  (Jupyter cell marker left over from the notebook export)
import hashlib
import importlib.util
import json
import os
import re
import traceback
from typing import Dict, List, Optional

import gradio as gr
import pandas as pd
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ==============================
# CONFIGURATION
# ==============================
# Model locations: the base checkpoint identifier, the LoRA adapter checkpoint
# directory, and the directory of the merged model.
BASE_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
LORA_ADAPTER_PATH = "32_V5_Model_V4/checkpoint-450"
MERGED_MODEL_PATH = "./qwen3_merged_model_5090"

# Opening therapist turn: shown in the chat window and used as the first entry
# of the conversation state. Kept as one unpunctuated lowercase string.
INITIAL_MESSAGE = (
    "hi i'm ellie thanks for coming in today "
    "i was created to talk to people in a safe and secure environment "
    "i'm not a therapist but i'm here to learn about people "
    "and would love to learn about you "
    "i'll ask a few questions to get us started "
    "and please feel free to tell me anything "
    "your answers are totally confidential "
    "are you ok with this"
)

class TherapeuticChatbot:
    """Fine-tuned Qwen chat model driven through three task-specific prompts.

    The same model is prompted to (1) summarise the conversation so far,
    (2) classify PHQ-8 symptom severities and (3) write the next therapist
    reply.  Summary and classification results are memoised in ``self.cache``
    under a SHA-256 hash of their JSON-serialised input.
    """

    # Number of PHQ-8 items; used to build the "nothing known yet" defaults.
    _PHQ8_ITEM_COUNT = 8

    # Compiled once: tag stripper and whitespace collapser for model output.
    _TAG_RE = re.compile(r'<[^>]+>')
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, use_merged: bool = False):
        """
        Load the tokenizer and the fine-tuned Qwen model.

        Args:
            use_merged: When True, load the merged checkpoint at
                ``MERGED_MODEL_PATH`` with 4-bit NF4 quantisation.  When False
                (default), load ``BASE_MODEL_PATH`` in float16 and attach the
                LoRA adapters found at ``LORA_ADAPTER_PATH``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("Warning: Running on CPU, which is slow. Use a GPU for better performance.")

        # Always load tokenizer from the base model
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        attn_impl = self._select_attn_implementation()

        if use_merged:
            # The quantisation config is only needed on this path.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                MERGED_MODEL_PATH,
                quantization_config=bnb_config,
                device_map="auto",
                attn_implementation=attn_impl,
            )
            print(f"Loaded merged 4-bit model from {MERGED_MODEL_PATH} with attn_impl={attn_impl}")
        else:
            # The base model is loaded unquantised in float16; the adapters are
            # attached on top of it.
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_PATH,
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation=attn_impl,
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTER_PATH,
                device_map="auto",
            )
            print(f"Loaded base model {BASE_MODEL_PATH} in float16 with LoRA adapters from {LORA_ADAPTER_PATH} and attn_impl={attn_impl}")

        self.model.eval()

        # Caching for reused computations (hash of input -> parsed result)
        self.cache = {}

        self.system_prompts = {
            "Summary": (
                "You are a cumulative summary AI for therapeutic conversations.\n\n"
                "Task: Generate a concise cumulative summary of the full conversation so far.\n\n"
                "Rules:\n"
                "1. Use only direct participant evidence—no inference or assumptions.\n"
                "2. Always include the most recent participant response.\n"
                "3. Capture:\n"
                " - Explicit emotional indicators (with quotes).\n"
                " - Depression symptoms mapped to PHQ-8 domains if mentioned.\n"
                " - Current participant state/mood based on evidence.\n"
                " - Therapist’s assessment approach if evident.\n"
                "4. Keep to 2–4 sentences (max 6–8), avoiding repetition.\n"
                "5. Exclude casual or irrelevant small talk.\n"
            ),
            "Classification": (
                "You are a PHQ-8 classification AI for therapeutic conversations.\n"
                "Task: Using the most recent participant response, prior 10 turns, and the cumulative summary, classify PHQ-8 symptom severities.\n\n"
                "Rules:\n"
                "1. Use only explicit participant statements as evidence. Ignore therapist text unless quoting the participant.\n"
                "2. Be conservative: if no direct evidence exists, assign 'Not explored'.\n"
                "3. Allowed severity values: 'Not explored', 'Not at all', 'Several days', 'More than half the days', 'Nearly every day'.\n"
                "4. Always provide an evidence mapping (quote or 'no evidence').\n"
                "5. Depression classification is based only on mapped severities:\n"
                " - If multiple symptoms are rated at higher severities (≥ 'More than half the days'), label as 'Depressed'.\n"
                " - Otherwise, 'Not depressed'.\n"
                "6. Always incorporate the most recent participant response into classification.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"evidence_mapping\": {\"PHQ1\": \"evidence or 'no evidence'\", ..., \"PHQ8\": \"...\"},\n"
                " \"phq8_scores\": {\"PHQ1\": \"severity\", ..., \"PHQ8\": \"severity\"},\n"
                " \"depression_classification\": \"Depressed\" or \"Not depressed\"\n"
                "}"
            ),
            "Response": (
                "You are a therapeutic response generator AI conducting a depression screening interview.\n\n"
                "Task: Based on the last 10 turns, the cumulative summary, and PHQ-8 results, generate a natural therapist reply.\n\n"
                "Guidelines:\n"
                "1. Responses must be clinically appropriate, and advance assessment.\n"
                "2. Incorporate therapeutic strategy (technique applied), response intent (1-sentence purpose), and emotion tag (tone).\n"
                "3. If symptoms are unclear, probe gently. If distress is explicit, respond with validation and support.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"strategy_used\": \"Therapeutic strategy applied\",\n"
                " \"response_intent\": \"1-sentence purpose\",\n"
                " \"emotion_tag\": \"Emotional tone\",\n"
                " \"therapist_response\": \"Response text\"\n"
                "}"
            ),
        }

        self.phq8_questions = [
            "Little interest or pleasure in doing things",
            "Feeling down, depressed, or hopeless",
            "Trouble falling or staying asleep, or sleeping too much",
            "Feeling tired or having little energy",
            "Poor appetite or overeating",
            "Feeling bad about yourself or that you are a failure",
            "Trouble concentrating on things",
            "Moving or speaking slowly or being fidgety/restless"
        ]

    @staticmethod
    def _select_attn_implementation() -> str:
        """Pick the attention backend to request when loading the model.

        Returns "flash_attention_2" only when a CUDA device with compute
        capability major version >= 8 is present AND the ``flash_attn``
        package can be found; otherwise "eager".  The package check keeps
        model loading from failing on GPUs where flash-attn is not installed.
        """
        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 8
            and importlib.util.find_spec("flash_attn") is not None
        ):
            return "flash_attention_2"
        return "eager"

    @classmethod
    def _default_scores(cls) -> Dict[str, str]:
        """Return PHQ1..PHQ8 all mapped to 'Not explored'."""
        return {f"PHQ{i+1}": "Not explored" for i in range(cls._PHQ8_ITEM_COUNT)}

    @classmethod
    def _default_evidence(cls) -> Dict[str, str]:
        """Return PHQ1..PHQ8 all mapped to 'no evidence'."""
        return {f"PHQ{i+1}": "no evidence" for i in range(cls._PHQ8_ITEM_COUNT)}

    @classmethod
    def _empty_classification(cls, raw_output: str) -> Dict:
        """Return a classification result carrying only defaults plus *raw_output*."""
        return {
            "raw_output": raw_output,
            "phq8_scores": cls._default_scores(),
            "depression_classification": "Not assessed",
            "evidence_mapping": cls._default_evidence(),
        }

    def _get_cache_key(self, func_name: str, input_data: str) -> str:
        """Generate a hash key for caching based on function and input."""
        key_str = f"{func_name}:{input_data}"
        return hashlib.sha256(key_str.encode()).hexdigest()

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
        """
        Generate a completion for *messages* using the tokenizer's chat template.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded completion (prompt tokens removed), stripped.

        Raises:
            Exception: Any error raised while templating, tokenising or
                generating is logged and re-raised.  It is deliberately NOT
                turned into a return value: an error string would otherwise be
                cached as a summary or shown to the user as the therapist reply.
        """
        try:
            input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)

            # Clear GPU cache before generation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=0.2,  # low temperature: sampling stays close to greedy
                    top_p=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,  # Explicitly set pad token
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2  # Prevent repetition
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            generated_text = self.tokenizer.decode(
                outputs[0][len(inputs["input_ids"][0]):],
                skip_special_tokens=True
            )
        except Exception as e:
            print(f"Error in _generate: {str(e)}")
            print(f"Traceback: {traceback.format_exc()}")
            raise
        finally:
            # Clear GPU cache after generation (also on failure)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return generated_text.strip()

    def summarize_conversation(self, conversation_history: List[Dict[str, str]]) -> str:
        """
        Generate a cumulative summary of the conversation with caching.

        Args:
            conversation_history: Full list of turn dicts; it is JSON-serialised
                and sent to the model as the user message.

        Returns:
            The ``cumulative_summary`` field when the model answers with a JSON
            object ("No summary available" if the field is absent), otherwise
            the model output with XML-like tags removed.
        """
        print(f"Summarizing conversation with {len(conversation_history)} messages")

        input_dict = {"conversation_history": conversation_history}
        input_data = json.dumps(input_dict)
        cache_key = self._get_cache_key("summarize", input_data)

        if cache_key in self.cache:
            print("Using cached summary")
            return self.cache[cache_key]

        messages = [
            {"role": "system", "content": self.system_prompts["Summary"]},
            {"role": "user", "content": input_data}
        ]

        print("Generating new summary...")
        output = self._generate(messages, max_new_tokens=1000)
        print(f"Raw summary output: {output[:300]}...")

        parsed = self._parse_json_object(output)
        if parsed is None:
            # Not a JSON object: treat the output itself as the summary.
            summary = self._remove_xml_tags(output).strip()
        else:
            summary = parsed.get("cumulative_summary", "No summary available")
            # Callers slice and strip the summary, so it must be a string.
            if summary is None:
                summary = "No summary available"
            elif not isinstance(summary, str):
                summary = str(summary)

        self.cache[cache_key] = summary
        print(f"Final summary: {summary[:300]}...")
        return summary

    def _extract_json_from_output(self, output: str) -> str:
        """
        Extract the JSON string from the output by removing XML tags and finding the JSON object.

        Returns "{}" for empty output.  When the text contains a '{', the first
        complete JSON value starting there is returned, so trailing text
        (including stray braces) does not break parsing.  If no complete value
        can be decoded, the span from the first '{' to the last '}' is returned
        unchanged; with no usable brace pair the cleaned text is returned as-is.
        """
        if not output:
            return "{}"

        cleaned = self._remove_xml_tags(output).strip()

        start_idx = cleaned.find('{')
        last_brace = cleaned.rfind('}')
        if start_idx == -1 or last_brace < start_idx:
            return cleaned if cleaned else "{}"

        try:
            _, end_idx = json.JSONDecoder().raw_decode(cleaned, start_idx)
        except json.JSONDecodeError:
            # Nothing decodable: hand back the widest brace-delimited span and
            # let the caller's json.loads report the failure.
            return cleaned[start_idx:last_brace + 1]
        return cleaned[start_idx:end_idx]

    def _parse_json_object(self, output: str) -> Optional[Dict]:
        """Return the JSON object embedded in *output*, or None.

        None is returned both when the text is not valid JSON and when it is
        valid JSON of another type (list, string, number), so callers can use
        ``.get`` on the result without an AttributeError.
        """
        try:
            parsed = json.loads(self._extract_json_from_output(output))
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _remove_xml_tags(self, text: str) -> str:
        """
        Remove XML-like tags such as <think>, <tool_call>, etc.

        Only the tags themselves are removed; text between an opening and a
        closing tag is kept.  Runs of whitespace are collapsed to one space.
        Non-string input is converted with ``str()`` first.
        """
        if not isinstance(text, str):
            text = str(text)
        cleaned = self._TAG_RE.sub('', text)
        cleaned = self._WHITESPACE_RE.sub(' ', cleaned).strip()
        return cleaned

    def classify_phq8(self, conversation_history: List[Dict[str, str]], cumulative_summary: str) -> Dict:
        """
        Classify PHQ-8 based on the recent context and cumulative summary with caching.

        Args:
            conversation_history: Turn dicts; only the last 10 are sent.
            cumulative_summary: Summary text; when empty or whitespace-only no
                model call is made and defaults are returned.

        Returns:
            Dict with keys ``raw_output``, ``phq8_scores``,
            ``depression_classification`` and ``evidence_mapping``.  Fields the
            model did not supply in the expected type keep their defaults
            ('Not explored' / 'Not assessed' / 'no evidence').
        """
        print("Classifying PHQ-8...")

        if not cumulative_summary or not cumulative_summary.strip():
            return self._empty_classification("No summary provided")

        # Use recent context (last 10 turns)
        recent_context = conversation_history[-10:]

        input_dict = {
            "recent_context": recent_context,
            "cumulative_summary": cumulative_summary
        }
        input_data = json.dumps(input_dict)
        cache_key = self._get_cache_key("classify", input_data)

        if cache_key in self.cache:
            print("Using cached classification")
            return self.cache[cache_key]

        messages = [
            {"role": "system", "content": self.system_prompts["Classification"]},
            {"role": "user", "content": input_data}
        ]

        output = self._generate(messages, max_new_tokens=1000)
        print(f"Raw classification output: {output[:300]}...")

        result = self._empty_classification(output)
        parsed = self._parse_json_object(output)
        if parsed is None:
            print("JSON decode error in classification: output is not a valid JSON object")
        else:
            # Accept each field only in the type the UI code relies on.
            for field in ("phq8_scores", "evidence_mapping"):
                if isinstance(parsed.get(field), dict):
                    result[field] = parsed[field]
            label = parsed.get("depression_classification")
            if isinstance(label, str):
                result["depression_classification"] = label

        self.cache[cache_key] = result
        return result

    def generate_response(self, conversation_history: List[Dict[str, str]], cumulative_summary: str,
                         classification_results: Dict) -> str:
        """
        Generate an empathetic therapeutic response.

        Args:
            conversation_history: Turn dicts; only the last 10 are sent.
            cumulative_summary: Summary text passed through to the model.
            classification_results: PHQ-8 results passed through to the model.

        Returns:
            The ``therapist_response`` field of the model's JSON answer when
            present as a string, otherwise the raw model output; in both cases
            with XML-like tags removed.  A fixed fallback sentence is returned
            if the result is empty.
        """
        print(f"Generating response for conversation with {len(conversation_history)} messages")

        # Use recent context (last 10 turns to avoid token limits)
        recent_context = conversation_history[-10:]

        input_dict = {
            "recent_context": recent_context,
            "cumulative_summary": cumulative_summary,
            "classification_results": classification_results
        }
        input_data = json.dumps(input_dict)

        messages = [
            {"role": "system", "content": self.system_prompts["Response"]},
            {"role": "user", "content": input_data}
        ]

        print("Generating therapeutic response...")
        output = self._generate(messages, max_new_tokens=512)
        print(f"Raw response output: {output[:300]}...")

        # Try to parse as JSON first, fallback to raw output
        response = output
        parsed = self._parse_json_object(output)
        if parsed is not None:
            candidate = parsed.get('therapist_response', output)
            # A null / non-text field is treated like a missing one.
            if isinstance(candidate, str):
                response = candidate

        # Clean up the response
        response = self._remove_xml_tags(response).strip()

        if not response:
            response = "I understand you're sharing something important with me. Can you tell me more about how you're feeling?"

        print(f"Final response: {response[:200]}...")
        return response

# Load the chatbot model once
# Module-level instance: `respond` below reads this global on every turn.
# use_merged=False selects the base-model + LoRA-adapter loading path of
# TherapeuticChatbot.__init__ (the merged-checkpoint path is not used here).
print("Loading therapeutic chatbot...")
chatbot = TherapeuticChatbot(use_merged=False)
print("Chatbot loaded successfully!")

def respond(message, history, conversation_state, current_summary, current_classification):
    """Run one chat turn: summarise, classify PHQ-8, then generate a reply.

    Args:
        message: Text typed by the participant.
        history: Chat display history as a list of (user, bot) pairs.
        conversation_state: List of ``{"speaker_role", "text"}`` turn dicts.
        current_summary: Summary text currently shown in the UI.
        current_classification: Classification text currently shown in the UI.

    Returns:
        A 4-tuple ``(history, conversation_state, summary_text,
        classification_text)``.

        * Empty, whitespace-only or non-string messages are ignored and the
          four inputs are returned unchanged.
        * On success the returned history and state are new lists extended by
          this turn; the lists passed in are not mutated.
        * If any step raises, the returned history gains the message paired
          with a fixed apology, while the state, summary and classification
          are returned as they were passed in, so the failed turn never
          becomes part of the model's context.
    """
    # Reject unusable input before touching it (slicing None would raise).
    if not isinstance(message, str) or not message.strip():
        return history, conversation_state, current_summary, current_classification

    # Work on copies so a failure half-way leaves the caller's lists intact.
    updated_history = list(history or [])

    try:
        pending_state = list(conversation_state or [])

        print(f"\n=== Processing message: {message[:50]}... ===")
        print(f"Current conversation state length: {len(pending_state)}")

        # Append user message to conversation state
        pending_state.append({"speaker_role": "participant", "text": message.strip()})
        print(f"Added participant message. New state length: {len(pending_state)}")

        # Generate summary
        print("Step 1: Generating summary...")
        summary = chatbot.summarize_conversation(pending_state)

        # Generate PHQ-8 classification
        print("Step 2: Generating classification...")
        phq8_classification = chatbot.classify_phq8(pending_state, summary)

        # Prepare classification results
        classification_results = {
            "phq8_scores": phq8_classification["phq8_scores"],
            "depression_classification": phq8_classification["depression_classification"]
        }

        # Generate response
        print("Step 3: Generating therapeutic response...")
        response = chatbot.generate_response(pending_state, summary, classification_results)

        if not response or not response.strip():
            print("no response detected")

        # Append bot response to conversation state
        pending_state.append({"speaker_role": "therapist", "text": response})
        print(f"Added therapist response. Final state length: {len(pending_state)}")

        # Format displays
        display_summary = summary if summary else "No summary available"
        scores = phq8_classification["phq8_scores"]
        if isinstance(scores, dict):
            score_lines = [f"{k}: {v}" for k, v in scores.items()]
        else:
            # Unexpected shape from the model: show it verbatim instead of failing.
            score_lines = [str(scores)]
        display_classification = (
            f"Depression Classification: {phq8_classification['depression_classification']}\n\nPHQ-8 Scores:\n"
            + "\n".join(score_lines)
        )

        # Update chat history for display (last, so nothing after it can fail)
        updated_history.append((message, response))

        print("=== Processing completed successfully ===\n")
        return updated_history, pending_state, display_summary, display_classification

    except Exception as e:
        error_msg = f"Error during processing: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)

        # Provide a fallback response instead of breaking
        fallback_response = "I apologize, but I'm having trouble processing your message right now. Could you please try rephrasing or asking again?"
        updated_history.append((message, fallback_response))

        return updated_history, conversation_state, current_summary, current_classification

def clear_chat():
    """Reset the UI to its opening state.

    Returns, in order: a chat display containing only the greeting, a
    conversation state containing only the greeting as a therapist turn, and
    empty strings for the summary and classification panels.
    """
    print("Clearing chat...")
    greeting_display = [[None, INITIAL_MESSAGE]]
    greeting_state = [{"speaker_role": "therapist", "text": INITIAL_MESSAGE}]
    return greeting_display, greeting_state, "", ""

# Create Gradio interface
# Components are created in the same order as before; that order, together
# with the nesting of the context managers, is what defines the page layout.
with gr.Blocks(title="Therapeutic Chatbot") as demo:
    gr.Markdown("# Therapeutic Chatbot")
    gr.Markdown("Engage in a therapeutic conversation. The bot will respond empathetically based on your inputs.")

    with gr.Row():
        # Wide column: transcript with the message entry row underneath.
        with gr.Column(scale=3):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                height=500,
                value=[[None, INITIAL_MESSAGE]],
            )
            with gr.Row():
                textbox = gr.Textbox(
                    placeholder="Type your message here and press Enter...",
                    label="Your Message",
                    scale=4,
                )
                send_btn = gr.Button("Send", scale=1)
        # Narrow column: read-only summary and classification panels.
        with gr.Column(scale=1):
            summary_display = gr.Textbox(label="Cumulative Summary", lines=10, interactive=False)
            classification_display = gr.Textbox(label="Classification Results", lines=10, interactive=False)

    # Conversation state, seeded with the greeting as the first therapist turn.
    state = gr.State([{"speaker_role": "therapist", "text": INITIAL_MESSAGE}])

    def handle_submit(message, history, conversation_state, summary, classification):
        """Process one submitted message and blank the input box.

        Delegates to `respond` and prepends an empty string, which is the new
        value of the textbox (the first output component).
        """
        new_history, new_state, new_summary, new_classification = respond(
            message, history, conversation_state, summary, classification
        )
        return "", new_history, new_state, new_summary, new_classification

    # Pressing Enter in the textbox and clicking "Send" trigger the same handler
    # with the same inputs and outputs; registered in that order.
    for submit_event in (textbox.submit, send_btn.click):
        submit_event(
            handle_submit,
            inputs=[textbox, chatbot_display, state, summary_display, classification_display],
            outputs=[textbox, chatbot_display, state, summary_display, classification_display],
        )

    clear_btn = gr.Button("Clear Chat")
    clear_btn.click(
        clear_chat,
        outputs=[chatbot_display, state, summary_display, classification_display],
    )

# Launch the Gradio app only when this file is executed directly.
# NOTE(review): share=True presumably requests a publicly reachable Gradio
# link, and debug=True presumably keeps the process attached and prints errors
# — confirm both are intended, since the greeting tells participants their
# answers are confidential.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)