<a href="https://colab.research.google.com/github/AyeshaAnzer1610/Emotion-Aware-CBT-Agent/blob/main/Emotion_Aware_CBT_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
# =========================================================
# Emotion-Aware CBT Agent (CPU, Production-Safe)
# Clinical Validation + CSV Logging + Safety Guardrails
# =========================================================

!pip -q install -U transformers torch gradio matplotlib sentencepiece huggingface_hub

import os, re, csv, uuid, time, random
from datetime import datetime, timezone
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)

# -----------------------
# Device (CPU ONLY)
# -----------------------
# Everything runs on the CPU. Note that predict_emotion / generate_reply in
# this file never move their tokenized inputs with .to(DEVICE), so changing
# this value alone would not be enough to use a GPU.
DEVICE = "cpu"
# Inference-only notebook: disable autograd globally for the whole process.
torch.set_grad_enabled(False)

# =========================================================
# 1) Load Models (CPU-safe)
# =========================================================
# Hugging Face Hub repo ids; from_pretrained fetches them (or reads the local
# cache) when this cell runs.
EMO_MODEL = "SamLowe/roberta-base-go_emotions"  # GoEmotions emotion classifier
GEN_MODEL = "google/flan-t5-base"  # seq2seq model that writes therapist replies

# Emotion classifier: tokenizer + sequence-classification model.
# eval() switches the module to inference mode (e.g. disables dropout).
emo_tok = AutoTokenizer.from_pretrained(EMO_MODEL)
emo_model = AutoModelForSequenceClassification.from_pretrained(EMO_MODEL).to(DEVICE)
emo_model.eval()

# Reply generator: tokenizer + encoder-decoder LM, also in inference mode.
gen_tok = AutoTokenizer.from_pretrained(GEN_MODEL)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL).to(DEVICE)
gen_model.eval()

# GoEmotions label names ordered by output column: LABEL_NAMES[i] names
# logit i. Built by explicit index instead of relying on id2label's dict
# insertion order happening to match the column order.
LABEL_NAMES = [emo_model.config.id2label[i] for i in range(emo_model.config.num_labels)]

# =========================================================
# 2) Emotion Mapping (GoEmotions → CBT)
# =========================================================
# Each CBT bucket is scored by the strongest of the GoEmotions labels mapped
# to it (see predict_emotion).
GO2CBT = {
    "Sadness": ["sadness", "grief", "remorse", "disappointment"],
    "Anxiety": ["fear", "nervousness"],
    "Anger":   ["anger", "annoyance", "disgust", "disapproval"],
    "Neutral": ["neutral"],
}

def predict_emotion(text, *, threshold=0.35):
    """Classify *text* into one of the CBT buckets defined in GO2CBT.

    Each classifier logit is passed through an independent sigmoid (the
    GoEmotions head is treated as multi-label, not softmax-normalised), and a
    bucket's score is the highest probability among its mapped labels.

    Args:
        text: Patient message; tokenized with truncation at 128 tokens.
        threshold: If the best bucket scores below this value the label
            falls back to "Neutral". Defaults to the original 0.35.

    Returns:
        (label, scores): the chosen bucket name and a dict with a score for
        every bucket in GO2CBT. A bucket none of whose labels exist in the
        model scores 0.0 instead of raising ValueError.
    """
    enc = emo_tok(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = emo_model(**enc).logits.squeeze(0).numpy()
    probs = 1 / (1 + np.exp(-logits))

    # Label name -> logit column, so each lookup is O(1) rather than a
    # linear LABEL_NAMES.index() scan per mapped label.
    column = {name: i for i, name in enumerate(LABEL_NAMES)}
    scores = {
        bucket: max(
            (probs[column[go_label]] for go_label in go_labels if go_label in column),
            default=0.0,
        )
        for bucket, go_labels in GO2CBT.items()
    }

    label = max(scores, key=scores.get)
    if scores[label] < threshold:
        label = "Neutral"
    return label, scores

# =========================================================
# 3) Safety Guardrails
# =========================================================
# Guardrail trigger phrases. They must be lowercase: contains_any lowercases
# the incoming message but not these lists. SELF_HARM is checked before
# VIOLENCE in chat_fn, so a message matching both gets CRISIS_RESPONSE.
# Both spellings of the apostrophe are listed because mobile keyboards
# commonly insert the curly one.
SELF_HARM = [
    "suicide", "kill myself", "want to die", "end my life", "hurt myself",
    "suicidal", "self harm", "self-harm", "cut myself", "better off dead",
    "don't want to live", "don’t want to live", "no reason to live",
]
VIOLENCE = ["kill him", "kill her", "stab", "shoot"]

# Fixed reply sent instead of a model generation when SELF_HARM matches.
CRISIS_RESPONSE = (
    "I’m really sorry you’re feeling this much pain. I can’t help with self-harm, "
    "but you deserve support. If you’re in immediate danger, please contact emergency services "
    "or a crisis hotline such as 988 (U.S./Canada). If you feel able, what’s making things feel unbearable right now?"
)

# Fixed reply sent instead of a model generation when VIOLENCE matches.
VIOLENCE_REFUSAL = (
    "I can’t help with harming anyone. If you’re feeling overwhelmed or angry, "
    "we can talk about what triggered it and safer ways to cope."
)

def contains_any(text, words):
    """Return True if any phrase in *words* appears in *text* as whole words.

    Matching lowercases *text*; the phrases are expected to be lowercase
    already and to begin and end with a letter. Each phrase must start and
    end on a word boundary, so "stab" does not fire on "unstable" or
    "establish", "shoot" does not fire on "photoshoot", and "want to die" does
    not fire on "want to diet". A few inflectional endings are accepted on
    the final word ("stabbed", "shooting", "suicides") so common forms still
    trigger. Any run of whitespace matches the spaces inside a phrase.

    Args:
        text: Message to screen.
        words: Iterable of lowercase keyword phrases.

    Returns:
        True if at least one phrase matches, otherwise False.
    """
    lowered = text.lower()
    for phrase in words:
        parts = phrase.split()
        if not parts:
            # A blank phrase would otherwise match every message.
            continue
        body = r"\s+".join(re.escape(part) for part in parts)
        if re.search(rf"\b{body}(?:s|es|d|ed|ing|bed|bing)?\b", lowered):
            return True
    return False

# =========================================================
# 4) Prompt Hygiene (CRITICAL FIX)
# =========================================================
def build_prompt(user_text, emotion):
    """Assemble the instruction prompt for one patient turn.

    The detected emotion label is given as context. The prompt ends with the
    bare "Therapist:" cue so the generator continues in the therapist role.
    There are no leading or trailing blank lines.
    """
    lines = [
        "You are a compassionate licensed CBT therapist.",
        "Respond naturally. Do not mention rules, policies, or instructions.",
        "",
        f"Emotion context: {emotion}",
        "",
        f"Patient: {user_text}",
        "",
        "Therapist:",
    ]
    return "\n".join(lines)

# Lowercase fragments suggesting the model echoed its instructions instead of
# answering the patient.
BAD_ECHO = ["do not judge", "policy", "instruction", "rule"]

def bad_output(text):
    """Return True if a generated reply should be rejected.

    A reply is rejected when it is shorter than 15 characters once
    surrounding whitespace is stripped, or when it contains any BAD_ECHO
    fragment (case-insensitive substring match).
    """
    if len(text.strip()) < 15:
        return True
    folded = text.lower()
    return any(fragment in folded for fragment in BAD_ECHO)

def generate_reply(prompt, *, max_new_tokens=120, temperature=0.8, top_p=0.9,
                   repetition_penalty=1.2, do_sample=True):
    """Generate a therapist reply for *prompt* with the seq2seq model.

    Sampling is enabled by default. Without ``do_sample=True``, transformers
    decodes greedily and ignores ``temperature``/``top_p``. Every call would
    then return the same text, which made the caller's "regenerate on bad
    output" retry a no-op.

    Args:
        prompt: Full prompt text, truncated to 512 input tokens.
        max_new_tokens, temperature, top_p, repetition_penalty, do_sample:
            Decoding settings forwarded to ``gen_model.generate``. The
            defaults are the values this function always used, plus sampling.

    Returns:
        The decoded reply with surrounding whitespace stripped. If the model
        emits a "Therapist:" cue, only the text after the last one is kept.
    """
    enc = gen_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        out = gen_model.generate(
            **enc,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    text = gen_tok.decode(out[0], skip_special_tokens=True)
    return text.split("Therapist:")[-1].strip()

# =========================================================
# 5) Conversation State
# =========================================================
conversation = []  # (role, text) tuples in arrival order
emotion_history = []  # one CBT label per processed patient turn

def render_chat():
    """Render the conversation as Markdown, one bold-role paragraph per turn."""
    return "\n\n".join(f"**{speaker}:** {utterance}" for speaker, utterance in conversation)

# =========================================================
# 6) CSV Logging
# =========================================================
SESSION_ID = str(uuid.uuid4())
CSV_FILE = "clinical_validation_1000.csv"

# Column order of every row chat_fn appends to CSV_FILE.
FIELDS = [
    "timestamp_utc","session_id","prompt_id",
    "patient_text","cbt_label",
    "sadness_score","anxiety_score","anger_score","neutral_score",
    "crisis_alert","response_length","latency_ms"
]


def _ensure_csv_header(path, fieldnames):
    """Write a header row to *path* unless the file already has content.

    A missing file and a zero-byte file (e.g. one truncated to reset the
    log) both get the header. Without the size check, an empty file would
    receive data rows with no header. A non-empty file is left untouched;
    its existing header is not compared with *fieldnames*.
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()


_ensure_csv_header(CSV_FILE, FIELDS)

# =========================================================
# 7) Core Chat Logic
# =========================================================
def chat_fn(user_text):
    """Process one patient message end to end and return the UI outputs.

    Order of operations:
      1. Blank (or None) input is ignored: nothing is classified, generated,
         appended to history, or logged.
      2. A SELF_HARM phrase yields the fixed CRISIS_RESPONSE (crisis=True).
      3. A VIOLENCE phrase yields the fixed VIOLENCE_REFUSAL (crisis=False).
      4. Otherwise the message is classified with predict_emotion, answered
         with generate_reply, and regenerated once if bad_output rejects it.
    In the two guardrail branches the models are skipped, and the turn is
    recorded as "Neutral" with placeholder scores.

    Each processed turn is appended to the module-level ``conversation`` and
    ``emotion_history`` lists and written as one row to CSV_FILE.

    Args:
        user_text: The patient's message from the Gradio textbox.

    Returns:
        (chat_markdown, status_line) for the ``chat_display`` and ``status``
        components.
    """
    if not (user_text or "").strip():
        return render_chat(), "Please enter a message."

    t0 = time.perf_counter()  # monotonic, unlike time.time()

    if contains_any(user_text, SELF_HARM):
        reply = CRISIS_RESPONSE
        label = "Neutral"
        scores = {"Sadness":0,"Anxiety":0,"Anger":0,"Neutral":1}
        crisis = True
    elif contains_any(user_text, VIOLENCE):
        reply = VIOLENCE_REFUSAL
        label = "Neutral"
        scores = {"Sadness":0,"Anxiety":0,"Anger":0,"Neutral":1}
        crisis = False
    else:
        label, scores = predict_emotion(user_text)
        prompt = build_prompt(user_text, label)
        reply = generate_reply(prompt)

        # One retry only; if the second reply is also rejected it is used anyway.
        if bad_output(reply):
            reply = generate_reply(prompt)

        # Heuristic flag: this turn and the previous one were both Sadness or
        # Anger. It is a logging signal, not a clinical risk assessment.
        crisis = len(emotion_history) > 0 and emotion_history[-1] in ["Sadness","Anger"] and label in ["Sadness","Anger"]

    latency = int((time.perf_counter() - t0) * 1000)

    conversation.append(("Patient", user_text))
    conversation.append(("Therapist", reply))
    emotion_history.append(label)

    # One CSV row per processed turn; prompt_id is the 1-based turn count.
    with open(CSV_FILE, "a", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=FIELDS).writerow({
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "session_id": SESSION_ID,
            "prompt_id": len(emotion_history),
            "patient_text": user_text,
            "cbt_label": label,
            "sadness_score": scores.get("Sadness",0),
            "anxiety_score": scores.get("Anxiety",0),
            "anger_score": scores.get("Anger",0),
            "neutral_score": scores.get("Neutral",0),
            "crisis_alert": crisis,
            "response_length": len(reply.split()),
            "latency_ms": latency
        })

    status = f"Emotion: {label} | Crisis: {crisis} | Latency: {latency}ms"
    return render_chat(), status

def reset_chat():
    """Clear the in-memory chat and start a new logging session.

    Empties ``conversation`` and ``emotion_history`` and assigns a fresh
    ``SESSION_ID``. chat_fn logs ``prompt_id`` as ``len(emotion_history)``,
    so it restarts at 1 after a reset. Keeping the old session id would
    write rows whose (session_id, prompt_id) pairs duplicate earlier ones.

    Returns:
        ("", "Reset."): an empty chat and a status line for the Gradio outputs.
    """
    global SESSION_ID
    conversation.clear()
    emotion_history.clear()
    SESSION_ID = str(uuid.uuid4())
    return "", "Reset."

# =========================================================
# 8) Batch Clinical Validation (1000+)
# =========================================================
# Synthetic validation messages grouped by the bucket they are meant to evoke.
THEMES = {
    "sad": ["I feel empty", "I feel hopeless", "I feel like a burden"],
    "anx": ["I can't stop worrying", "I'm scared of the future", "My mind is racing"],
    "ang": ["I'm furious", "I feel wronged", "I snapped at someone"],
    "neu": ["Today felt normal", "I'm not sure what I feel", "Nothing special happened"],
}

def batch_run(n=1000, *, seed=None):
    """Send *n* synthetic patient messages through chat_fn for validation.

    Each iteration picks a theme uniformly at random, then a message from
    that theme, and passes it to chat_fn. Each turn is therefore screened,
    classified, answered, appended to the shared conversation and
    emotion_history, and logged to CSV exactly like a live turn.

    NOTE(review): because batch turns share the live chat state, they show
    up in the chat display and affect the crisis heuristic of the next live
    turn.

    Args:
        n: Number of messages to send.
        seed: Optional seed for a private ``random.Random``, so a validation
            run can be reproduced. ``None`` (default) keeps the original
            behavior of drawing from the module-level ``random`` state.

    Returns:
        A one-line summary for the UI.
    """
    rng = random if seed is None else random.Random(seed)
    theme_lists = list(THEMES.values())  # loop-invariant
    for _ in range(n):
        # Same draw order as before: theme first, then a message from it.
        theme = rng.choice(theme_lists)
        msg = rng.choice(theme)
        chat_fn(msg)
    return f"Batch run complete: {n} prompts logged."

# =========================================================
# 9) Gradio UI (STABLE)
# =========================================================
# Build the Gradio UI. Components render in the order they are created inside
# the Blocks context, so statement order here defines the page layout.
# NOTE(review): chat_fn/reset_chat read and write module-level globals
# (conversation, emotion_history, SESSION_ID). Every browser session,
# including every visitor of the public share link, therefore sees and
# appends to the same conversation. Consider per-session gr.State before
# exposing this to real patients.
with gr.Blocks() as demo:
    gr.Markdown("# Emotion-Aware CBT Agent (Clinical Validation)")
    chat_display = gr.Markdown()  # rendered transcript from render_chat()
    status = gr.Textbox(label="Status", interactive=False)  # emotion / crisis / latency line
    user_input = gr.Textbox(label="Patient")

    with gr.Row():
        send = gr.Button("Send")
        clear = gr.Button("Reset")

    # Button click and Enter in the textbox trigger the same handler. The
    # textbox is not among the outputs, so it is not cleared after sending.
    send.click(chat_fn, inputs=user_input, outputs=[chat_display, status])
    user_input.submit(chat_fn, inputs=user_input, outputs=[chat_display, status])
    clear.click(reset_chat, outputs=[chat_display, status])

    gr.Markdown("## Batch Clinical Validation")
    batch_btn = gr.Button("Run 1000 Validation Prompts")
    batch_out = gr.Textbox()

    # Runs all 1000 prompts within a single event. Only batch_out is updated,
    # so chat_display does not reflect the batch turns until the next chat event.
    batch_btn.click(lambda: batch_run(1000), outputs=batch_out)

# share=True also serves the app on a temporary public *.gradio.live URL.
demo.launch(share=True)


Loading weights:   0%|          | 0/201 [00:00<?, ?it/s]

[1mRobertaForSequenceClassification LOAD REPORT[0m from: SamLowe/roberta-base-go_emotions
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Loading weights:   0%|          | 0/282 [00:00<?, ?it/s]



Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://3da96462ba829bb5fe.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


