In [None]:
import os
import json
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from typing import List, Dict
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import re
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# ==============================
# CONFIGURATION
# ==============================
# Hub id of the base model; both the tokenizer and (in the adapter path) the
# float16 base weights are loaded from here.
BASE_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
# LoRA adapter checkpoint, resolved relative to the current working directory.
LORA_ADAPTER_PATH = "32_V5_Model/checkpoint-400"
# Optional pre-merged model directory, only read when use_merged=True.
MERGED_MODEL_PATH = "./qwen3_merged_model_5090"



class TherapeuticChatbot:
    """Fine-tuned Qwen3 wrapper that summarises transcripts and scores PHQ-8.

    The model is either the float16 base model with LoRA adapters applied
    (``use_merged=False``) or a pre-merged checkpoint loaded with 4-bit NF4
    quantization (``use_merged=True``).
    """

    def __init__(self, use_merged: bool = False):
        """
        Load the tokenizer and the fine-tuned Qwen model (LoRA adapters or merged model).

        Args:
            use_merged: If True, load MERGED_MODEL_PATH with a 4-bit NF4
                BitsAndBytes config. Otherwise load BASE_MODEL_PATH in
                float16 and wrap it with the adapters at LORA_ADAPTER_PATH.
        """
        import importlib.util  # local: only used to probe for the flash_attn package

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("Warning: Running on CPU, which is slow. Use a GPU for better performance.")

        # ============================================================
        # FIXED TOKENIZER LOADING (MATCHING TRAINING CODE EXACTLY)
        # ============================================================
        print(f"Loading tokenizer from: {BASE_MODEL_PATH}")
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        # Apply the EXACT SAME tokenizer configuration as in training
        # Set PAD to EOS (standard for Qwen to avoid resizing or using UNK)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"Set pad_token to eos_token: {self.tokenizer.pad_token}")

        # Set padding side to left for Flash Attention compatibility
        self.tokenizer.padding_side = "left"
        print("Set padding_side to 'left' for Flash Attention")

        # BOS token is None by design in Qwen3
        self.tokenizer.bos_token = None
        print("BOS token set to None (Qwen3 design)")

        # Verify token configuration
        print(f"EOS token: {self.tokenizer.eos_token} (ID: {self.tokenizer.eos_token_id})")
        print(f"PAD token: {self.tokenizer.pad_token} (ID: {self.tokenizer.pad_token_id})")
        print(f"BOS token: {self.tokenizer.bos_token}")

        # ============================================================
        # MODEL LOADING WITH PROPER CONFIG ALIGNMENT
        # ============================================================
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

        # flash_attention_2 needs an Ampere-or-newer GPU *and* the flash_attn
        # package; without the package from_pretrained raises, so fall back.
        has_flash_attn = importlib.util.find_spec("flash_attn") is not None
        attn_impl = (
            "flash_attention_2"
            if torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 8
            and has_flash_attn
            else "eager"
        )

        if use_merged:
            self.model = AutoModelForCausalLM.from_pretrained(
                MERGED_MODEL_PATH,
                quantization_config=bnb_config,
                device_map="auto",
                attn_implementation=attn_impl
            )
            print(f"Loaded merged 4-bit model from {MERGED_MODEL_PATH} with attn_impl={attn_impl}")
        else:
            # The adapter path loads the base weights in float16 (no quantization).
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_PATH,
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation=attn_impl
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTER_PATH,
                device_map="auto"
            )
            print(f"Loaded float16 base model {BASE_MODEL_PATH} with LoRA adapters from {LORA_ADAPTER_PATH} and attn_impl={attn_impl}")

        # Align model/generation config special-token ids with the tokenizer.
        self._fix_model_config_tokens()
        self.model.eval()

        self.system_prompts = {
            "Summary": (
                "You are a cumulative summary AI for therapeutic conversations.\n\n"
                "Task: Generate a concise cumulative summary of the full conversation so far.\n\n"
                "Rules:\n"
                "1. Use only direct participant evidence—no inference or assumptions.\n"
                "2. Always include the most recent participant response.\n"
                "3. Capture:\n"
                " - Explicit emotional indicators (with quotes).\n"
                " - Depression symptoms mapped to PHQ-8 domains if mentioned.\n"
                " - Current participant state/mood based on evidence.\n"
                " - Therapist's assessment approach if evident.\n"
                "4. Keep to 2–4 sentences (max 6–8), avoiding repetition.\n"
                "5. Exclude casual or irrelevant small talk.\n"
            ),
            "Classification": (
                "You are a PHQ-8 classification AI for therapeutic conversations.\n"
                "Task: Using the most recent participant response, prior 10 turns, and the cumulative summary, classify PHQ-8 symptom severities.\n\n"
                "Rules:\n"
                "1. Use only explicit participant statements as evidence. Ignore therapist text unless quoting the participant.\n"
                "2. Be conservative: if no direct evidence exists, assign 'Not explored'.\n"
                "3. Allowed severity values: 'Not explored', 'Not at all', 'Several days', 'More than half the days', 'Nearly every day'.\n"
                "4. Always provide an evidence mapping (quote or 'no evidence').\n"
                "5. Depression classification is based only on mapped severities:\n"
                " - If multiple symptoms are rated at higher severities (≥ 'More than half the days'), label as 'Depressed'.\n"
                " - Otherwise, 'Not depressed'.\n"
                "6. Always incorporate the most recent participant response into classification.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"evidence_mapping\": {\"PHQ1\": \"evidence or 'no evidence'\", ..., \"PHQ8\": \"...\"},\n"
                " \"phq8_scores\": {\"PHQ1\": \"severity\", ..., \"PHQ8\": \"severity\"},\n"
                " \"depression_classification\": \"Depressed\" or \"Not depressed\"\n"
                "}"
            ),
            "Response": (
                "You are a therapeutic response generator AI conducting a depression screening interview.\n\n"
                "Task: Based on the last 10 turns, the cumulative summary, and PHQ-8 results, generate a natural therapist reply.\n\n"
                "Guidelines:\n"
                "1. Responses must be clinically appropriate, and advance assessment.\n"
                "2. Incorporate therapeutic strategy (technique applied), response intent (1-sentence purpose), and emotion tag (tone).\n"
                "3. If symptoms are unclear, probe gently. If distress is explicit, respond with validation and support.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"strategy_used\": \"Therapeutic strategy applied\",\n"
                " \"response_intent\": \"1-sentence purpose\",\n"
                " \"emotion_tag\": \"Emotional tone\",\n"
                " \"therapist_response\": \"Response text\"\n"
                "}"
            ),
        }

        self.phq8_questions = [
            "Little interest or pleasure in doing things",
            "Feeling down, depressed, or hopeless",
            "Trouble falling or staying asleep, or sleeping too much",
            "Feeling tired or having little energy",
            "Poor appetite or overeating",
            "Feeling bad about yourself or that you are a failure",
            "Trouble concentrating on things",
            "Moving or speaking slowly or being fidgety/restless"
        ]

    def _fix_model_config_tokens(self):
        """Copy the tokenizer's pad/eos ids into the model and generation configs and clear bos."""
        if hasattr(self.model.config, 'pad_token_id'):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
        if hasattr(self.model.config, 'bos_token_id'):
            self.model.config.bos_token_id = None  # Qwen3 doesn't use BOS
        if hasattr(self.model.config, 'eos_token_id'):
            self.model.config.eos_token_id = self.tokenizer.eos_token_id

        # Also fix generation config if it exists
        if hasattr(self.model, 'generation_config'):
            if self.model.generation_config is not None:
                self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
                self.model.generation_config.bos_token_id = None
                self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id
        print("Model config aligned with tokenizer")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
        """Greedy-decode a reply to a chat-format message list.

        Args:
            messages: ``[{"role": ..., "content": ...}, ...]`` passed to the
                tokenizer's chat template with a generation prompt appended.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            Only the newly generated text (prompt removed), special tokens
            skipped, surrounding whitespace stripped.
        """
        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # With device_map="auto" the weights are placed by accelerate; send the
        # inputs to wherever the model reports it lives instead of a bare "cuda".
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                # Sampling knobs are unused in greedy mode; clear any values
                # inherited from generation_config so transformers does not warn.
                temperature=None,
                top_p=None,
                top_k=None,
                pad_token_id=self.tokenizer.pad_token_id,  # Explicit pad token
                eos_token_id=self.tokenizer.eos_token_id,  # Explicit eos token
                bos_token_id=None,  # Qwen3 doesn't use BOS
                # NOTE(review): a penalty > 1 also discourages tokens JSON must
                # repeat (quotes, braces, "PHQ" keys); revisit if parses fail often.
                repetition_penalty=1.2
            )

        generated_text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return generated_text.strip()

    def summarize_conversation(self, conversation_history: List[Dict[str, str]]) -> str:
        """Ask the model for a cumulative summary of the whole conversation.

        The history is JSON-encoded under ``"conversation_history"``; XML-style
        tags are stripped from the reply and whitespace is collapsed.
        """
        input_data = json.dumps({"conversation_history": conversation_history})
        messages = [
            {"role": "system", "content": self.system_prompts["Summary"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        return self._remove_xml_tags(output).strip()

    @staticmethod
    def _default_classification() -> Dict:
        """Return the conservative fallback used when model output is unusable."""
        return {
            "phq8_scores": {f"PHQ{i+1}": "Not explored" for i in range(8)},
            "depression_classification": "Not depressed",
            "evidence_mapping": {f"PHQ{i+1}": "no evidence" for i in range(8)}
        }

    def classify_phq8(self, conversation_history: List[Dict[str, str]], cumulative_summary: str) -> Dict:
        """Classify PHQ-8 item severities from the last 10 turns plus the summary.

        Returns:
            A dict with keys ``phq8_scores`` (dict), ``depression_classification``
            (str) and ``evidence_mapping`` (dict). Any key the model omits, or
            returns with the wrong type, falls back to the conservative default;
            unparseable or non-object JSON yields the full default.
        """
        recent_history = conversation_history[-10:]
        input_data = json.dumps({"recent_context": recent_history, "cumulative_summary": cumulative_summary})
        messages = [
            {"role": "system", "content": self.system_prompts["Classification"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        json_str = self._extract_json_from_output(output)

        result = self._default_classification()
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return result
        # Valid JSON is not necessarily an object (e.g. a bare list or number).
        if not isinstance(parsed, dict):
            return result

        expected_types = {
            "phq8_scores": dict,
            "depression_classification": str,
            "evidence_mapping": dict,
        }
        for key, expected_type in expected_types.items():
            value = parsed.get(key)
            if isinstance(value, expected_type):
                result[key] = value
        return result

    def _extract_json_from_output(self, output: str) -> str:
        """Return the outermost ``{...}`` span of the tag-stripped output.

        If there is no ``{``, or no ``}`` after it, the cleaned text is returned
        as-is so the caller's ``json.loads`` fails and falls back to defaults.
        """
        cleaned = self._remove_xml_tags(output).strip()
        start_idx = cleaned.find('{')
        # Compare raw rfind result (-1 when absent) before adding 1 for slicing.
        end_idx = cleaned.rfind('}')
        if start_idx == -1 or end_idx < start_idx:
            return cleaned
        return cleaned[start_idx:end_idx + 1]

    def _remove_xml_tags(self, text: str) -> str:
        """Strip ``<...>`` tags, collapse whitespace runs to one space, and trim."""
        if not isinstance(text, str):
            text = str(text)
        cleaned = re.sub(r'<[^>]+>', '', text)
        return re.sub(r'\s+', ' ', cleaned).strip()


def parse_transcript(file_path: str) -> List[Dict[str, str]]:
    """Load a transcript CSV and merge consecutive rows from the same speaker.

    The file should have ``speaker`` and ``value`` columns (header matched
    case-insensitively). If either is missing, both are derived from the first
    column: the first run of word characters is the speaker and everything after
    the following space is the utterance.

    A speaker whose name contains "ellie" is labelled ``therapist``; everyone
    else is ``participant``. Rows with a missing or blank ``value`` are skipped
    so pandas' NaN does not leak into the prompt as the literal text "nan".

    Returns:
        ``[{"speaker_role": ..., "text": ...}, ...]`` turns in file order.
    """
    df = pd.read_csv(file_path)
    df.columns = df.columns.str.lower()
    if 'speaker' not in df.columns or 'value' not in df.columns:
        df['speaker'] = df.iloc[:, 0].str.extract(r'(\w+)', expand=False)
        df['value'] = df.iloc[:, 0].str.extract(r'\w+ (.*)', expand=False)

    history, current_text, current_speaker = [], [], None
    for _, row in df.iterrows():
        value = row['value']
        # Empty cells come back as NaN; str(NaN) would inject "nan".
        if pd.isna(value) or not str(value).strip():
            continue
        role = 'therapist' if 'ellie' in str(row['speaker']).lower() else 'participant'
        if role != current_speaker:
            if current_speaker:
                history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
            current_speaker, current_text = role, [str(value)]
        else:
            current_text.append(str(value))
    if current_speaker:
        history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
    return history


def map_severity_to_number(severity: str) -> int:
    """Convert a PHQ-8 severity label to its 0-3 item score.

    Matching ignores surrounding whitespace and letter case, so minor
    formatting drift in model output (e.g. ``"several days"``) is still
    scored. "Not explored", unrecognised labels and non-string values score 0.
    """
    severity_map = {
        "Not explored": 0,
        "Not at all": 0,
        "Several days": 1,
        "More than half the days": 2,
        "Nearly every day": 3
    }
    if not isinstance(severity, str):
        return 0
    normalized = {label.casefold(): score for label, score in severity_map.items()}
    return normalized.get(severity.strip().casefold(), 0)


def evaluate_model(test_folder: str, label_csv: str, use_merged=False):
    """Run the chatbot over every labelled transcript and report PHQ-8 metrics.

    Each ``<participant_id>_*.csv`` in *test_folder* that has a matching
    ``Participant_ID`` row in *label_csv* is summarised and classified, then
    compared with the ground-truth binary label and total score.

    Side effects: writes ``evaluation_results.csv``, a timestamped metrics
    JSON, and two timestamped PNG plots to the current working directory.
    If no labelled transcripts are found, a message is printed and nothing
    is written.

    Args:
        test_folder: Directory containing transcript CSV files.
        label_csv: CSV with ``Participant_ID`` and ``PHQ_Binary``/``PHQ_Score``
            columns (``PHQ8_Binary``/``PHQ8_Score`` are renamed to these).
        use_merged: Forwarded to ``TherapeuticChatbot``.
    """
    chatbot = TherapeuticChatbot(use_merged=use_merged)
    labels_df = pd.read_csv(label_csv)
    labels_df.columns = labels_df.columns.str.strip()
    if 'PHQ_Binary' not in labels_df.columns or 'PHQ_Score' not in labels_df.columns:
        labels_df.rename(columns={'PHQ8_Binary': 'PHQ_Binary', 'PHQ8_Score': 'PHQ_Score'}, inplace=True)

    predictions_binary, true_binary, predictions_total, true_total, results = [], [], [], [], []

    # os.listdir order is arbitrary; sort so runs and the results CSV are reproducible.
    for file_name in sorted(os.listdir(test_folder)):
        if not file_name.lower().endswith('.csv'):
            continue
        try:
            pid = int(file_name.split('_')[0])
        except ValueError:
            continue
        label_row = labels_df[labels_df['Participant_ID'] == pid]
        if label_row.empty:
            continue

        true_binary_value = label_row['PHQ_Binary'].values[0]
        true_label = "Depressed" if true_binary_value == 1 else "Not depressed"
        true_total_score = label_row['PHQ_Score'].values[0]

        conversation_history = parse_transcript(os.path.join(test_folder, file_name))
        summary = chatbot.summarize_conversation(conversation_history)
        classification = chatbot.classify_phq8(conversation_history, summary)

        predicted_label = classification.get('depression_classification', "Not depressed")
        # Guard against a malformed model answer where phq8_scores is not a dict.
        scores = classification.get('phq8_scores')
        if not isinstance(scores, dict):
            scores = {}
        predicted_total_score = sum(
            map_severity_to_number(scores.get(f"PHQ{i+1}", "Not explored")) for i in range(8)
        )

        predictions_binary.append(predicted_label)
        true_binary.append(true_label)
        predictions_total.append(predicted_total_score)
        true_total.append(true_total_score)

        results.append({
            "Participant_ID": pid,
            "Predicted_Binary": 1 if predicted_label == "Depressed" else 0,
            "True_Binary": true_binary_value,
            "Predicted_Score": predicted_total_score,
            "True_Score": true_total_score,
            "Raw_Classification_Output": json.dumps(classification)
        })

    # Without any scored transcripts the metrics are undefined and max() below
    # would raise on an empty list.
    if not results:
        print(f"\nNo labelled transcripts found in {test_folder}; skipping metrics and plots.")
        return

    pd.DataFrame(results).to_csv("evaluation_results.csv", index=False)

    # Metrics
    pred_binary_nums = [1 if p == "Depressed" else 0 for p in predictions_binary]
    true_binary_nums = [1 if t == "Depressed" else 0 for t in true_binary]
    accuracy_bin = accuracy_score(true_binary_nums, pred_binary_nums)
    precision_bin = precision_score(true_binary_nums, pred_binary_nums, zero_division=0)
    recall_bin = recall_score(true_binary_nums, pred_binary_nums, zero_division=0)
    f1_bin = f1_score(true_binary_nums, pred_binary_nums, zero_division=0)
    # Pin labels so the matrix is always 2x2 (matching the heatmap tick labels),
    # even when only one class appears in the data.
    cm_bin = confusion_matrix(true_binary_nums, pred_binary_nums, labels=[0, 1])

    total_mae = np.mean(np.abs(np.array(predictions_total) - np.array(true_total)))

    print("\nBinary Depression Classification Results:")
    print(f"Accuracy: {accuracy_bin:.4f}")
    print(f"Precision: {precision_bin:.4f}")
    print(f"Recall: {recall_bin:.4f}")
    print(f"F1 Score: {f1_bin:.4f}")
    print("Confusion Matrix:")
    print(cm_bin)

    print("\nPHQ-8 Total Score Results:")
    print(f"Total PHQ-8 Score MAE: {total_mae:.4f}")

    metrics = {
        "binary_classification": {
            "accuracy": accuracy_bin,
            "precision": precision_bin,
            "recall": recall_bin,
            "f1_score": f1_bin,
            "confusion_matrix": cm_bin.tolist()
        },
        "phq8_score": {"MAE": total_mae}
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with open(f"evaluation_metrics_{timestamp}.json", "w") as f:
        json.dump(metrics, f, indent=4)

    # --- Visualization Section ---
    plt.figure(figsize=(5,4))
    sns.heatmap(cm_bin, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Not Depressed", "Depressed"],
                yticklabels=["Not Depressed", "Depressed"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Binary Depression Classification Confusion Matrix")
    plt.tight_layout()
    plt.savefig(f"confusion_matrix_{timestamp}.png")
    plt.close()

    plt.figure(figsize=(6,4))
    plt.scatter(true_total, predictions_total, alpha=0.6)
    # Draw the y = x reference across both true and predicted score ranges.
    upper = max(max(true_total), max(predictions_total))
    plt.plot([0, upper], [0, upper], color="red", linestyle="--")
    plt.xlabel("True PHQ-8 Score")
    plt.ylabel("Predicted PHQ-8 Score")
    plt.title("PHQ-8 True vs Predicted Scores")
    plt.tight_layout()
    plt.savefig(f"phq8_true_vs_pred_{timestamp}.png")
    plt.close()

    print(f"\nMetrics and plots saved with timestamp {timestamp}")


if __name__ == "__main__":
    # Score the LoRA-adapter model (not the merged one) on the transcript
    # folder against the labels in full_test_split.csv.
    evaluate_model(
        "Extracted_Text_Transcript_DAIC",
        "full_test_split.csv",
        use_merged=False,
    )

In [None]:
import os
import json
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from typing import List, Dict
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import re
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# ==============================
# CONFIGURATION
# ==============================
# Hub id of the base model; both the tokenizer and (in the adapter path) the
# float16 base weights are loaded from here.
BASE_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
# LoRA adapter checkpoint, resolved relative to the current working directory.
LORA_ADAPTER_PATH = "./checkpoint-400"
# Optional pre-merged model directory, only read when use_merged=True.
MERGED_MODEL_PATH = "./qwen3_merged_model_5090"


class TherapeuticChatbot:
    """Fine-tuned Qwen3 wrapper that summarises transcripts and scores PHQ-8.

    The model is either the float16 base model with LoRA adapters applied
    (``use_merged=False``) or a pre-merged checkpoint loaded with 4-bit NF4
    quantization (``use_merged=True``).
    """

    def __init__(self, use_merged: bool = False):
        """
        Load the tokenizer and the fine-tuned Qwen model (LoRA adapters or merged model).

        Args:
            use_merged: If True, load MERGED_MODEL_PATH with a 4-bit NF4
                BitsAndBytes config. Otherwise load BASE_MODEL_PATH in
                float16 and wrap it with the adapters at LORA_ADAPTER_PATH.
        """
        import importlib.util  # local: only used to probe for the flash_attn package

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("Warning: Running on CPU, which is slow. Use a GPU for better performance.")

        # Always load tokenizer from the base model
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

        # flash_attention_2 needs an Ampere-or-newer GPU *and* the flash_attn
        # package; without the package from_pretrained raises, so fall back.
        has_flash_attn = importlib.util.find_spec("flash_attn") is not None
        attn_impl = (
            "flash_attention_2"
            if torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 8
            and has_flash_attn
            else "eager"
        )

        if use_merged:
            self.model = AutoModelForCausalLM.from_pretrained(
                MERGED_MODEL_PATH,
                quantization_config=bnb_config,
                device_map="auto",
                attn_implementation=attn_impl
            )
            print(f"Loaded merged 4-bit model from {MERGED_MODEL_PATH} with attn_impl={attn_impl}")
        else:
            # The adapter path loads the base weights in float16 (no quantization).
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_PATH,
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation=attn_impl
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTER_PATH,
                device_map="auto"
            )
            print(f"Loaded float16 base model {BASE_MODEL_PATH} with LoRA adapters from {LORA_ADAPTER_PATH} and attn_impl={attn_impl}")

        self.model.eval()

        self.system_prompts = {
            "Summary": (
                "You are a cumulative summary AI for therapeutic conversations.\n\n"
                "Task: Generate a concise cumulative summary of the full conversation so far.\n\n"
                "Rules:\n"
                "1. Use only direct participant evidence—no inference or assumptions.\n"
                "2. Always include the most recent participant response.\n"
                "3. Capture:\n"
                " - Explicit emotional indicators (with quotes).\n"
                " - Depression symptoms mapped to PHQ-8 domains if mentioned.\n"
                " - Current participant state/mood based on evidence.\n"
                " - Therapist’s assessment approach if evident.\n"
                "4. Keep to 2–4 sentences (max 6–8), avoiding repetition.\n"
                "5. Exclude casual or irrelevant small talk.\n"
            ),
            "Classification": (
                "You are a PHQ-8 classification AI for therapeutic conversations.\n"
                "Task: Using the most recent participant response, prior 10 turns, and the cumulative summary, classify PHQ-8 symptom severities.\n\n"
                "Rules:\n"
                "1. Use only explicit participant statements as evidence. Ignore therapist text unless quoting the participant.\n"
                "2. Be conservative: if no direct evidence exists, assign 'Not explored'.\n"
                "3. Allowed severity values: 'Not explored', 'Not at all', 'Several days', 'More than half the days', 'Nearly every day'.\n"
                "4. Always provide an evidence mapping (quote or 'no evidence').\n"
                "5. Depression classification is based only on mapped severities:\n"
                " - If multiple symptoms are rated at higher severities (≥ 'More than half the days'), label as 'Depressed'.\n"
                " - Otherwise, 'Not depressed'.\n"
                "6. Always incorporate the most recent participant response into classification.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"evidence_mapping\": {\"PHQ1\": \"evidence or 'no evidence'\", ..., \"PHQ8\": \"...\"},\n"
                " \"phq8_scores\": {\"PHQ1\": \"severity\", ..., \"PHQ8\": \"severity\"},\n"
                " \"depression_classification\": \"Depressed\" or \"Not depressed\"\n"
                "}"
            ),
            "Response": (
                "You are a therapeutic response generator AI conducting a depression screening interview.\n\n"
                "Task: Based on the last 10 turns, the cumulative summary, and PHQ-8 results, generate a natural therapist reply.\n\n"
                "Guidelines:\n"
                "1. Responses must be clinically appropriate, and advance assessment.\n"
                "2. Incorporate therapeutic strategy (technique applied), response intent (1-sentence purpose), and emotion tag (tone).\n"
                "3. If symptoms are unclear, probe gently. If distress is explicit, respond with validation and support.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"strategy_used\": \"Therapeutic strategy applied\",\n"
                " \"response_intent\": \"1-sentence purpose\",\n"
                " \"emotion_tag\": \"Emotional tone\",\n"
                " \"therapist_response\": \"Response text\"\n"
                "}"
            ),
        }

        self.phq8_questions = [
            "Little interest or pleasure in doing things",
            "Feeling down, depressed, or hopeless",
            "Trouble falling or staying asleep, or sleeping too much",
            "Feeling tired or having little energy",
            "Poor appetite or overeating",
            "Feeling bad about yourself or that you are a failure",
            "Trouble concentrating on things",
            "Moving or speaking slowly or being fidgety/restless"
        ]

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
        """Greedy-decode a reply to a chat-format message list.

        Args:
            messages: ``[{"role": ..., "content": ...}, ...]`` passed to the
                tokenizer's chat template with a generation prompt appended.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            Only the newly generated text (prompt removed), special tokens
            skipped, surrounding whitespace stripped.
        """
        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # With device_map="auto" the weights are placed by accelerate; send the
        # inputs to wherever the model reports it lives instead of a bare "cuda".
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                # Sampling knobs are unused in greedy mode; clear any values
                # inherited from generation_config so transformers does not warn.
                temperature=None,
                top_p=None,
                top_k=None,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                # NOTE(review): a penalty > 1 also discourages tokens JSON must
                # repeat (quotes, braces, "PHQ" keys); revisit if parses fail often.
                repetition_penalty=1.2
            )

        generated_text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return generated_text.strip()

    def summarize_conversation(self, conversation_history: List[Dict[str, str]]) -> str:
        """Ask the model for a cumulative summary of the whole conversation.

        The history is JSON-encoded under ``"conversation_history"``; XML-style
        tags are stripped from the reply and whitespace is collapsed.
        """
        input_data = json.dumps({"conversation_history": conversation_history})
        messages = [
            {"role": "system", "content": self.system_prompts["Summary"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        return self._remove_xml_tags(output).strip()

    @staticmethod
    def _default_classification() -> Dict:
        """Return the conservative fallback used when model output is unusable."""
        return {
            "phq8_scores": {f"PHQ{i+1}": "Not explored" for i in range(8)},
            "depression_classification": "Not depressed",
            "evidence_mapping": {f"PHQ{i+1}": "no evidence" for i in range(8)}
        }

    def classify_phq8(self, conversation_history: List[Dict[str, str]], cumulative_summary: str) -> Dict:
        """Classify PHQ-8 item severities from the last 10 turns plus the summary.

        Returns:
            A dict with keys ``phq8_scores`` (dict), ``depression_classification``
            (str) and ``evidence_mapping`` (dict). Any key the model omits, or
            returns with the wrong type, falls back to the conservative default;
            unparseable or non-object JSON yields the full default.
        """
        recent_history = conversation_history[-10:]
        input_data = json.dumps({"recent_context": recent_history, "cumulative_summary": cumulative_summary})
        messages = [
            {"role": "system", "content": self.system_prompts["Classification"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        json_str = self._extract_json_from_output(output)

        result = self._default_classification()
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return result
        # Valid JSON is not necessarily an object (e.g. a bare list or number).
        if not isinstance(parsed, dict):
            return result

        expected_types = {
            "phq8_scores": dict,
            "depression_classification": str,
            "evidence_mapping": dict,
        }
        for key, expected_type in expected_types.items():
            value = parsed.get(key)
            if isinstance(value, expected_type):
                result[key] = value
        return result

    def _extract_json_from_output(self, output: str) -> str:
        """Return the outermost ``{...}`` span of the tag-stripped output.

        If there is no ``{``, or no ``}`` after it, the cleaned text is returned
        as-is so the caller's ``json.loads`` fails and falls back to defaults.
        """
        cleaned = self._remove_xml_tags(output).strip()
        start_idx = cleaned.find('{')
        # Compare raw rfind result (-1 when absent) before adding 1 for slicing.
        end_idx = cleaned.rfind('}')
        if start_idx == -1 or end_idx < start_idx:
            return cleaned
        return cleaned[start_idx:end_idx + 1]

    def _remove_xml_tags(self, text: str) -> str:
        """Strip ``<...>`` tags, collapse whitespace runs to one space, and trim."""
        if not isinstance(text, str):
            text = str(text)
        cleaned = re.sub(r'<[^>]+>', '', text)
        return re.sub(r'\s+', ' ', cleaned).strip()


def parse_transcript(file_path: str) -> List[Dict[str, str]]:
    """Load a transcript CSV and merge consecutive rows from the same speaker.

    The file should have ``speaker`` and ``value`` columns (header matched
    case-insensitively). If either is missing, both are derived from the first
    column: the first run of word characters is the speaker and everything after
    the following space is the utterance.

    A speaker whose name contains "ellie" is labelled ``therapist``; everyone
    else is ``participant``. Rows with a missing or blank ``value`` are skipped
    so pandas' NaN does not leak into the prompt as the literal text "nan".

    Returns:
        ``[{"speaker_role": ..., "text": ...}, ...]`` turns in file order.
    """
    df = pd.read_csv(file_path)
    df.columns = df.columns.str.lower()
    if 'speaker' not in df.columns or 'value' not in df.columns:
        df['speaker'] = df.iloc[:, 0].str.extract(r'(\w+)', expand=False)
        df['value'] = df.iloc[:, 0].str.extract(r'\w+ (.*)', expand=False)

    history, current_text, current_speaker = [], [], None
    for _, row in df.iterrows():
        value = row['value']
        # Empty cells come back as NaN; str(NaN) would inject "nan".
        if pd.isna(value) or not str(value).strip():
            continue
        role = 'therapist' if 'ellie' in str(row['speaker']).lower() else 'participant'
        if role != current_speaker:
            if current_speaker:
                history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
            current_speaker, current_text = role, [str(value)]
        else:
            current_text.append(str(value))
    if current_speaker:
        history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
    return history


def map_severity_to_number(severity: str) -> int:
    """Convert a PHQ-8 severity label to its 0-3 item score.

    Matching ignores surrounding whitespace and letter case, so minor
    formatting drift in model output (e.g. ``"several days"``) is still
    scored. "Not explored", unrecognised labels and non-string values score 0.
    """
    severity_map = {
        "Not explored": 0,
        "Not at all": 0,
        "Several days": 1,
        "More than half the days": 2,
        "Nearly every day": 3
    }
    if not isinstance(severity, str):
        return 0
    normalized = {label.casefold(): score for label, score in severity_map.items()}
    return normalized.get(severity.strip().casefold(), 0)


def evaluate_model(test_folder: str, label_csv: str, use_merged=False):
    """Run the chatbot over every labelled transcript and report PHQ-8 metrics.

    Each ``<participant_id>_*.csv`` in *test_folder* that has a matching
    ``Participant_ID`` row in *label_csv* is summarised and classified, then
    compared with the ground-truth binary label and total score.

    Side effects: writes ``evaluation_results.csv``, a timestamped metrics
    JSON, and two timestamped PNG plots to the current working directory.
    If no labelled transcripts are found, a message is printed and nothing
    is written.

    Args:
        test_folder: Directory containing transcript CSV files.
        label_csv: CSV with ``Participant_ID`` and ``PHQ_Binary``/``PHQ_Score``
            columns (``PHQ8_Binary``/``PHQ8_Score`` are renamed to these).
        use_merged: Forwarded to ``TherapeuticChatbot``.
    """
    chatbot = TherapeuticChatbot(use_merged=use_merged)
    labels_df = pd.read_csv(label_csv)
    labels_df.columns = labels_df.columns.str.strip()
    if 'PHQ_Binary' not in labels_df.columns or 'PHQ_Score' not in labels_df.columns:
        labels_df.rename(columns={'PHQ8_Binary': 'PHQ_Binary', 'PHQ8_Score': 'PHQ_Score'}, inplace=True)

    predictions_binary, true_binary, predictions_total, true_total, results = [], [], [], [], []

    # os.listdir order is arbitrary; sort so runs and the results CSV are reproducible.
    for file_name in sorted(os.listdir(test_folder)):
        if not file_name.lower().endswith('.csv'):
            continue
        try:
            pid = int(file_name.split('_')[0])
        except ValueError:
            continue
        label_row = labels_df[labels_df['Participant_ID'] == pid]
        if label_row.empty:
            continue

        true_binary_value = label_row['PHQ_Binary'].values[0]
        true_label = "Depressed" if true_binary_value == 1 else "Not depressed"
        true_total_score = label_row['PHQ_Score'].values[0]

        conversation_history = parse_transcript(os.path.join(test_folder, file_name))
        summary = chatbot.summarize_conversation(conversation_history)
        classification = chatbot.classify_phq8(conversation_history, summary)

        predicted_label = classification.get('depression_classification', "Not depressed")
        # Guard against a malformed model answer where phq8_scores is not a dict.
        scores = classification.get('phq8_scores')
        if not isinstance(scores, dict):
            scores = {}
        predicted_total_score = sum(
            map_severity_to_number(scores.get(f"PHQ{i+1}", "Not explored")) for i in range(8)
        )

        predictions_binary.append(predicted_label)
        true_binary.append(true_label)
        predictions_total.append(predicted_total_score)
        true_total.append(true_total_score)

        results.append({
            "Participant_ID": pid,
            "Predicted_Binary": 1 if predicted_label == "Depressed" else 0,
            "True_Binary": true_binary_value,
            "Predicted_Score": predicted_total_score,
            "True_Score": true_total_score,
            "Raw_Classification_Output": json.dumps(classification)
        })

    # Without any scored transcripts the metrics are undefined and max() below
    # would raise on an empty list.
    if not results:
        print(f"\nNo labelled transcripts found in {test_folder}; skipping metrics and plots.")
        return

    pd.DataFrame(results).to_csv("evaluation_results.csv", index=False)

    # Metrics
    pred_binary_nums = [1 if p == "Depressed" else 0 for p in predictions_binary]
    true_binary_nums = [1 if t == "Depressed" else 0 for t in true_binary]
    accuracy_bin = accuracy_score(true_binary_nums, pred_binary_nums)
    precision_bin = precision_score(true_binary_nums, pred_binary_nums, zero_division=0)
    recall_bin = recall_score(true_binary_nums, pred_binary_nums, zero_division=0)
    f1_bin = f1_score(true_binary_nums, pred_binary_nums, zero_division=0)
    # Pin labels so the matrix is always 2x2 (matching the heatmap tick labels),
    # even when only one class appears in the data.
    cm_bin = confusion_matrix(true_binary_nums, pred_binary_nums, labels=[0, 1])

    total_mae = np.mean(np.abs(np.array(predictions_total) - np.array(true_total)))

    print("\nBinary Depression Classification Results:")
    print(f"Accuracy: {accuracy_bin:.4f}")
    print(f"Precision: {precision_bin:.4f}")
    print(f"Recall: {recall_bin:.4f}")
    print(f"F1 Score: {f1_bin:.4f}")
    print("Confusion Matrix:")
    print(cm_bin)

    print("\nPHQ-8 Total Score Results:")
    print(f"Total PHQ-8 Score MAE: {total_mae:.4f}")

    metrics = {
        "binary_classification": {
            "accuracy": accuracy_bin,
            "precision": precision_bin,
            "recall": recall_bin,
            "f1_score": f1_bin,
            "confusion_matrix": cm_bin.tolist()
        },
        "phq8_score": {"MAE": total_mae}
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with open(f"evaluation_metrics_{timestamp}.json", "w") as f:
        json.dump(metrics, f, indent=4)

    # --- Visualization Section ---
    plt.figure(figsize=(5,4))
    sns.heatmap(cm_bin, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Not Depressed", "Depressed"],
                yticklabels=["Not Depressed", "Depressed"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Binary Depression Classification Confusion Matrix")
    plt.tight_layout()
    plt.savefig(f"confusion_matrix_{timestamp}.png")
    plt.close()

    plt.figure(figsize=(6,4))
    plt.scatter(true_total, predictions_total, alpha=0.6)
    # Draw the y = x reference across both true and predicted score ranges.
    upper = max(max(true_total), max(predictions_total))
    plt.plot([0, upper], [0, upper], color="red", linestyle="--")
    plt.xlabel("True PHQ-8 Score")
    plt.ylabel("Predicted PHQ-8 Score")
    plt.title("PHQ-8 True vs Predicted Scores")
    plt.tight_layout()
    plt.savefig(f"phq8_true_vs_pred_{timestamp}.png")
    plt.close()

    print(f"\nMetrics and plots saved with timestamp {timestamp}")


if __name__ == "__main__":
    # Score the LoRA-adapter model (not the merged one) on the transcript
    # folder against the labels in full_test_split.csv.
    evaluate_model(
        "Extracted_Text_Transcript_DAIC",
        "full_test_split.csv",
        use_merged=False,
    )


In [None]:
import os
import json
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from typing import List, Dict
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import re
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# ---- Model locations (read by TherapeuticChatbot.__init__) ----

# Base model id/path passed to from_pretrained for the tokenizer and, in the
# non-merged mode, for the model that the LoRA adapters are attached to.
BASE_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
# Directory holding the LoRA adapter checkpoint applied when use_merged=False.
LORA_ADAPTER_PATH = "32_V5_Model/checkpoint-400"
# Directory of a merged model, loaded (4-bit) only when use_merged=True.
MERGED_MODEL_PATH = "./qwen3_merged_model_5090"



class TherapeuticChatbot:
    """Fine-tuned Qwen wrapper for cumulative summarisation and PHQ-8 classification.

    Each task is driven by its own system prompt (``self.system_prompts``) and
    greedy decoding via ``_generate``.
    """

    def __init__(self, use_merged: bool = False):
        """
        Load fine-tuned Qwen model (LoRA adapters or merged model).

        Args:
            use_merged: If True, load the merged model from MERGED_MODEL_PATH
                with 4-bit NF4 quantization. Otherwise load the base model from
                BASE_MODEL_PATH in float16 (not quantized) and attach the LoRA
                adapters from LORA_ADAPTER_PATH.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("Warning: Running on CPU, which is slow. Use a GPU for better performance.")

        # ============================================================
        # FIXED TOKENIZER LOADING (MATCHING TRAINING CODE EXACTLY)
        # ============================================================
        print(f"Loading tokenizer from: {BASE_MODEL_PATH}")
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        # Apply the EXACT SAME tokenizer configuration as in training
        # Set PAD to EOS (standard for Qwen to avoid resizing or using UNK)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"Set pad_token to eos_token: {self.tokenizer.pad_token}")

        # Set padding side to left for Flash Attention compatibility
        self.tokenizer.padding_side = "left"
        print("Set padding_side to 'left' for Flash Attention")

        # BOS token is None by design in Qwen3
        self.tokenizer.bos_token = None
        print("BOS token set to None (Qwen3 design)")

        # Verify token configuration
        print(f"EOS token: {self.tokenizer.eos_token} (ID: {self.tokenizer.eos_token_id})")
        print(f"PAD token: {self.tokenizer.pad_token} (ID: {self.tokenizer.pad_token_id})")
        print(f"BOS token: {self.tokenizer.bos_token}")

        # ============================================================
        # MODEL LOADING WITH PROPER CONFIG ALIGNMENT
        # ============================================================
        # Only used for the merged model; the LoRA path loads the base in fp16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

        # FlashAttention-2 requires compute capability >= 8 (Ampere or newer).
        attn_impl = "flash_attention_2" if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else "eager"

        if use_merged:
            self.model = AutoModelForCausalLM.from_pretrained(
                MERGED_MODEL_PATH,
                quantization_config=bnb_config,
                device_map="auto",
                attn_implementation=attn_impl
            )
            print(f"Loaded merged 4-bit model from {MERGED_MODEL_PATH} with attn_impl={attn_impl}")
        else:
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_PATH,
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation=attn_impl
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                LORA_ADAPTER_PATH,
                device_map="auto"
            )
            # The base model above is loaded in float16 without quantization_config.
            print(f"Loaded base model {BASE_MODEL_PATH} in float16 with LoRA adapters from {LORA_ADAPTER_PATH} and attn_impl={attn_impl}")

        # Function to align model config with tokenizer (from training code)
        self._fix_model_config_tokens()
        self.model.eval()

        self.system_prompts = {
            "Summary": (
                "You are a cumulative summary AI for therapeutic conversations.\n\n"
                "Task: Generate a concise cumulative summary of the full conversation so far.\n\n"
                "Rules:\n"
                "1. Use only direct participant evidence—no inference or assumptions.\n"
                "2. Always include the most recent participant response.\n"
                "3. Capture:\n"
                " - Explicit emotional indicators (with quotes).\n"
                " - Depression symptoms mapped to PHQ-8 domains if mentioned.\n"
                " - Current participant state/mood based on evidence.\n"
                " - Therapist's assessment approach if evident.\n"
                "4. Keep to 2–4 sentences (max 6–8), avoiding repetition.\n"
                "5. Exclude casual or irrelevant small talk.\n"
            ),
            "Classification": (
                "You are a PHQ-8 classification AI for therapeutic conversations.\n"
                "Task: Using the most recent participant response, prior 10 turns, and the cumulative summary, classify PHQ-8 symptom severities.\n\n"
                "Rules:\n"
                "1. Use only explicit participant statements as evidence. Ignore therapist text unless quoting the participant.\n"
                "2. Be conservative: if no direct evidence exists, assign 'Not explored'.\n"
                "3. Allowed severity values: 'Not explored', 'Not at all', 'Several days', 'More than half the days', 'Nearly every day'.\n"
                "4. Always provide an evidence mapping (quote or 'no evidence').\n"
                "5. Depression classification is based only on mapped severities:\n"
                " - If multiple symptoms are rated at higher severities (≥ 'More than half the days'), label as 'Depressed'.\n"
                " - Otherwise, 'Not depressed'.\n"
                "6. Always incorporate the most recent participant response into classification.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"evidence_mapping\": {\"PHQ1\": \"evidence or 'no evidence'\", ..., \"PHQ8\": \"...\"},\n"
                " \"phq8_scores\": {\"PHQ1\": \"severity\", ..., \"PHQ8\": \"severity\"},\n"
                " \"depression_classification\": \"Depressed\" or \"Not depressed\"\n"
                "}"
            ),
            "Response": (
                "You are a therapeutic response generator AI conducting a depression screening interview.\n\n"
                "Task: Based on the last 10 turns, the cumulative summary, and PHQ-8 results, generate a natural therapist reply.\n\n"
                "Guidelines:\n"
                "1. Responses must be clinically appropriate, and advance assessment.\n"
                "2. Incorporate therapeutic strategy (technique applied), response intent (1-sentence purpose), and emotion tag (tone).\n"
                "3. If symptoms are unclear, probe gently. If distress is explicit, respond with validation and support.\n\n"
                "Output only valid JSON in this exact format:\n"
                "{\n"
                " \"strategy_used\": \"Therapeutic strategy applied\",\n"
                " \"response_intent\": \"1-sentence purpose\",\n"
                " \"emotion_tag\": \"Emotional tone\",\n"
                " \"therapist_response\": \"Response text\"\n"
                "}"
            ),
        }

        self.phq8_questions = [
            "Little interest or pleasure in doing things",
            "Feeling down, depressed, or hopeless",
            "Trouble falling or staying asleep, or sleeping too much",
            "Feeling tired or having little energy",
            "Poor appetite or overeating",
            "Feeling bad about yourself or that you are a failure",
            "Trouble concentrating on things",
            "Moving or speaking slowly or being fidgety/restless"
        ]

    def _fix_model_config_tokens(self):
        """Align model config with tokenizer to prevent warnings"""
        if hasattr(self.model.config, 'pad_token_id'):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
        if hasattr(self.model.config, 'bos_token_id'):
            self.model.config.bos_token_id = None  # Qwen3 doesn't use BOS
        if hasattr(self.model.config, 'eos_token_id'):
            self.model.config.eos_token_id = self.tokenizer.eos_token_id

        # Also fix generation config if it exists
        if hasattr(self.model, 'generation_config'):
            if self.model.generation_config is not None:
                self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
                self.model.generation_config.bos_token_id = None
                self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id
        print("Model config aligned with tokenizer")

    def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512,
                  repetition_penalty: float = 1.2) -> str:
        """Run greedy decoding on a chat-templated prompt and return the completion.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}``) rendered
                with the tokenizer's chat template plus a generation prompt.
            max_new_tokens: Maximum number of tokens to generate.
            repetition_penalty: Forwarded to ``generate``; the default equals the
                value previously hard-coded here.

        Returns:
            Decoded generated text (prompt tokens excluded, special tokens
            skipped), stripped of surrounding whitespace.
        """
        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # Greedy decoding. Sampling knobs (temperature/top_p) are ignored
                # when do_sample=False, so they are deliberately not passed.
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,  # Explicit pad token
                eos_token_id=self.tokenizer.eos_token_id,  # Explicit eos token
                bos_token_id=None,  # Qwen3 doesn't use BOS
                repetition_penalty=repetition_penalty,
            )

        # Drop the echoed prompt; keep only newly generated tokens.
        generated_text = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True
        )
        return generated_text.strip()

    def summarize_conversation(self, conversation_history: List[Dict[str, str]]) -> str:
        """Generate a cumulative summary of the whole conversation.

        Args:
            conversation_history: Turns as produced by ``parse_transcript``
                (``{"speaker_role": ..., "text": ...}``); serialised to JSON.

        Returns:
            The model's summary with XML-like tags removed and whitespace collapsed.
        """
        input_data = json.dumps({"conversation_history": conversation_history})
        messages = [
            {"role": "system", "content": self.system_prompts["Summary"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        return self._remove_xml_tags(output).strip()

    def classify_phq8(self, conversation_history: List[Dict[str, str]], cumulative_summary: str) -> Dict:
        """Classify PHQ-8 item severities and overall depression status.

        Args:
            conversation_history: Full list of turns; only the last 10 are sent.
            cumulative_summary: Summary covering the whole conversation.

        Returns:
            Dict with ``phq8_scores`` (dict), ``depression_classification`` (str)
            and ``evidence_mapping`` (dict). Malformed model output falls back
            to "Not explored" / "Not depressed" / "no evidence" defaults.
        """
        # Older context is carried by the cumulative summary.
        recent_history = conversation_history[-10:]
        input_data = json.dumps({"recent_context": recent_history, "cumulative_summary": cumulative_summary})
        messages = [
            {"role": "system", "content": self.system_prompts["Classification"]},
            {"role": "user", "content": input_data}
        ]
        output = self._generate(messages, max_new_tokens=1000)
        return self._parse_classification(output)

    @staticmethod
    def _default_classification() -> Dict:
        """Return a fresh fallback classification (no evidence for any item)."""
        return {
            "phq8_scores": {f"PHQ{i+1}": "Not explored" for i in range(8)},
            "depression_classification": "Not depressed",
            "evidence_mapping": {f"PHQ{i+1}": "no evidence" for i in range(8)}
        }

    def _parse_classification(self, output: str) -> Dict:
        """Parse raw classifier output into a well-typed classification dict.

        Each field falls back to its default independently when the JSON is
        invalid, is not an object, or the field has the wrong type. Case or
        whitespace variants of "Depressed"/"Not depressed" are canonicalised;
        other label strings are kept verbatim.
        """
        result = self._default_classification()
        try:
            parsed = json.loads(self._extract_json_from_output(output))
        except json.JSONDecodeError:
            return result
        if not isinstance(parsed, dict):
            return result

        scores = parsed.get("phq8_scores")
        if isinstance(scores, dict):
            result["phq8_scores"] = scores

        label = parsed.get("depression_classification")
        if isinstance(label, str):
            canonical = {"depressed": "Depressed", "not depressed": "Not depressed"}
            result["depression_classification"] = canonical.get(label.strip().lower(), label)

        evidence = parsed.get("evidence_mapping")
        if isinstance(evidence, dict):
            result["evidence_mapping"] = evidence
        return result

    def _extract_json_from_output(self, output: str) -> str:
        """Return the outermost ``{...}`` span of *output* after tag removal.

        If there is no ``{`` or no ``}`` after it, the cleaned text is returned
        unchanged (and will fail JSON parsing downstream).
        """
        cleaned = self._remove_xml_tags(output).strip()
        start_idx = cleaned.find('{')
        end_idx = cleaned.rfind('}')
        if start_idx == -1 or end_idx < start_idx:
            return cleaned
        return cleaned[start_idx:end_idx + 1]

    def _remove_xml_tags(self, text: str) -> str:
        """Strip ``<...>`` tags and collapse all whitespace runs to one space."""
        if not isinstance(text, str):
            text = str(text)
        cleaned = re.sub(r'<[^>]+>', '', text)
        return re.sub(r'\s+', ' ', cleaned).strip()


def parse_transcript(file_path: str) -> List[Dict[str, str]]:
    """Parse a transcript CSV into a list of merged speaker turns.

    Consecutive rows attributed to the same role are concatenated (space
    separated) into one turn. A speaker whose name contains "ellie"
    (case-insensitive) is the therapist; any other speaker is the participant.
    Rows with a missing or blank utterance are skipped, so the string "nan"
    never ends up in the conversation text.

    If the file lacks ``speaker``/``value`` columns (after stripping and
    lower-casing headers), the first column is split into its first word
    (speaker) and the remainder (utterance).
    NOTE(review): that fallback layout is inferred from the regexes — confirm
    against the actual transcript files.

    Args:
        file_path: Path (or any source accepted by ``pd.read_csv``).

    Returns:
        List of ``{"speaker_role": "therapist" | "participant", "text": str}``
        in transcript order.
    """
    df = pd.read_csv(file_path)
    df.columns = df.columns.str.strip().str.lower()
    if 'speaker' not in df.columns or 'value' not in df.columns:
        df['speaker'] = df.iloc[:, 0].str.extract(r'(\w+)', expand=False)
        df['value'] = df.iloc[:, 0].str.extract(r'\w+ (.*)', expand=False)

    history: List[Dict[str, str]] = []
    current_speaker = None
    current_text: List[str] = []
    for speaker, value in zip(df['speaker'], df['value']):
        # Empty cells come back as NaN; str(NaN) would inject "nan" text.
        if pd.isna(value) or not str(value).strip():
            continue
        role = 'therapist' if 'ellie' in str(speaker).lower() else 'participant'
        if role != current_speaker:
            if current_speaker is not None:
                history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
            current_speaker, current_text = role, []
        current_text.append(str(value).strip())
    if current_speaker is not None:
        history.append({'speaker_role': current_speaker, 'text': ' '.join(current_text)})
    return history


def map_severity_to_number(severity: str) -> int:
    """Map a PHQ-8 severity label to its 0-3 item score.

    Matching ignores case and extra whitespace, so LLM variants such as
    "several days" or "Nearly  Every Day" score correctly instead of silently
    falling to 0. "Not explored", unknown labels and non-string values
    (e.g. ``None``) all score 0.

    Args:
        severity: Severity label as emitted by the classifier.

    Returns:
        Integer item score in ``{0, 1, 2, 3}``.
    """
    severity_map = {
        "not explored": 0,
        "not at all": 0,
        "several days": 1,
        "more than half the days": 2,
        "nearly every day": 3,
    }
    if not isinstance(severity, str):
        return 0
    normalized = " ".join(severity.split()).lower()
    return severity_map.get(normalized, 0)


def evaluate_model(test_folder: str, label_csv: str, use_merged=False):
    """Evaluate PHQ-8 predictions of ``TherapeuticChatbot`` against labels.

    Every ``<participant_id>_*.csv`` file in ``test_folder`` (processed in
    sorted order) that has a matching ``Participant_ID`` row in ``label_csv``
    is summarised and classified. Binary depression metrics and the MAE of the
    PHQ-8 total score are then printed and saved.

    Files written to the current directory: ``evaluation_results.csv``,
    ``evaluation_metrics_<timestamp>.json``, ``confusion_matrix_<timestamp>.png``
    and ``phq8_true_vs_pred_<timestamp>.png``.

    Args:
        test_folder: Directory containing transcript CSV files.
        label_csv: CSV with ``Participant_ID`` plus ``PHQ_Binary``/``PHQ_Score``
            (or ``PHQ8_Binary``/``PHQ8_Score``, which are renamed).
        use_merged: Forwarded to ``TherapeuticChatbot``.

    Returns:
        The metrics dict written to JSON, or ``None`` when no transcript could
        be matched to a label row (nothing is computed or saved in that case).
    """
    chatbot = TherapeuticChatbot(use_merged=use_merged)
    labels_df = pd.read_csv(label_csv)
    labels_df.columns = labels_df.columns.str.strip()
    if 'PHQ_Binary' not in labels_df.columns or 'PHQ_Score' not in labels_df.columns:
        labels_df.rename(columns={'PHQ8_Binary': 'PHQ_Binary', 'PHQ8_Score': 'PHQ_Score'}, inplace=True)

    pred_binary, true_binary, pred_totals, true_totals, results = [], [], [], [], []

    # Sorted so repeated runs process files in the same order (listdir is arbitrary).
    csv_files = sorted(f for f in os.listdir(test_folder) if f.lower().endswith('.csv'))
    for file_name in tqdm(csv_files, desc="Processing transcripts"):
        try:
            pid = int(file_name.split('_')[0])
        except ValueError:
            continue  # file name does not start with a numeric participant id
        label_row = labels_df[labels_df['Participant_ID'] == pid]
        if label_row.empty:
            continue

        true_binary_raw = label_row['PHQ_Binary'].values[0]
        true_bin = 1 if true_binary_raw == 1 else 0
        true_score = label_row['PHQ_Score'].values[0]

        conversation_history = parse_transcript(os.path.join(test_folder, file_name))
        summary = chatbot.summarize_conversation(conversation_history)
        classification = chatbot.classify_phq8(conversation_history, summary)

        pred_bin = 1 if classification['depression_classification'] == "Depressed" else 0
        scores = classification.get('phq8_scores')
        if not isinstance(scores, dict):
            scores = {}  # malformed model output: treat every item as unexplored
        pred_score = sum(
            map_severity_to_number(scores.get(f"PHQ{i+1}", "Not explored")) for i in range(8)
        )

        pred_binary.append(pred_bin)
        true_binary.append(true_bin)
        pred_totals.append(pred_score)
        true_totals.append(true_score)

        results.append({
            "Participant_ID": pid,
            "Predicted_Binary": pred_bin,
            "True_Binary": true_binary_raw,
            "Predicted_Score": pred_score,
            "True_Score": true_score,
            "Raw_Classification_Output": json.dumps(classification)
        })

    if not results:
        # Metrics and max() below are undefined on empty inputs.
        print("No transcripts matched a label row; nothing to evaluate.")
        return None

    pd.DataFrame(results).to_csv("evaluation_results.csv", index=False)

    # Metrics
    accuracy_bin = float(accuracy_score(true_binary, pred_binary))
    precision_bin = float(precision_score(true_binary, pred_binary, zero_division=0))
    recall_bin = float(recall_score(true_binary, pred_binary, zero_division=0))
    f1_bin = float(f1_score(true_binary, pred_binary, zero_division=0))
    # Fixed label set keeps the matrix 2x2 even if one class never occurs,
    # matching the two tick labels used in the heatmap below.
    cm_bin = confusion_matrix(true_binary, pred_binary, labels=[0, 1])

    total_mae = float(np.mean(np.abs(np.array(pred_totals) - np.array(true_totals))))

    print("\nBinary Depression Classification Results:")
    print(f"Accuracy: {accuracy_bin:.4f}")
    print(f"Precision: {precision_bin:.4f}")
    print(f"Recall: {recall_bin:.4f}")
    print(f"F1 Score: {f1_bin:.4f}")
    print("Confusion Matrix:")
    print(cm_bin)

    print("\nPHQ-8 Total Score Results:")
    print(f"Total PHQ-8 Score MAE: {total_mae:.4f}")

    metrics = {
        "binary_classification": {
            "accuracy": accuracy_bin,
            "precision": precision_bin,
            "recall": recall_bin,
            "f1_score": f1_bin,
            "confusion_matrix": cm_bin.tolist()
        },
        "phq8_score": {"MAE": total_mae}
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with open(f"evaluation_metrics_{timestamp}.json", "w") as f:
        json.dump(metrics, f, indent=4)

    # --- Visualization Section ---
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm_bin, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Not Depressed", "Depressed"],
                yticklabels=["Not Depressed", "Depressed"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Binary Depression Classification Confusion Matrix")
    plt.tight_layout()
    plt.savefig(f"confusion_matrix_{timestamp}.png")
    plt.close()

    plt.figure(figsize=(6, 4))
    plt.scatter(true_totals, pred_totals, alpha=0.6)
    # Identity line spans both axes so over-predictions are not clipped.
    axis_max = max(max(true_totals), max(pred_totals))
    plt.plot([0, axis_max], [0, axis_max], color="red", linestyle="--")
    plt.xlabel("True PHQ-8 Score")
    plt.ylabel("Predicted PHQ-8 Score")
    plt.title("PHQ-8 True vs Predicted Scores")
    plt.tight_layout()
    plt.savefig(f"phq8_true_vs_pred_{timestamp}.png")
    plt.close()

    print(f"\nMetrics and plots saved with timestamp {timestamp}")
    return metrics


if __name__ == "__main__":
    # Score the LoRA-adapter model (use_merged=False) on the transcripts in
    # Extracted_Text_Transcript_DAIC against the labels in full_test_split.csv.
    evaluate_model(
        test_folder="Extracted_Text_Transcript_DAIC",
        label_csv="full_test_split.csv",
        use_merged=False,
    )