In [None]:
!pip install -q flask flask_cors pyngrok transformers accelerate bitsandbytes shap lime scikit-learn numpy

In [1]:
from flask import Flask, request, jsonify
from flask_cors import CORS
from pyngrok import ngrok
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
"""
Explainable AI (XAI) utilities for the Press Conference Simulator.

Modes:
- "semantic": Fast baseline ‚Äî show which sentences in the speech are most similar to the question.
- "shap":     Token-level attribution over the SPEECH for how much it helps explain the QUESTION.
- "attention":Top tokens by last-layer attention (model must be loaded with output_attentions=True).
- "lime":     Placeholder for classifier models exposing predict_proba (off by default).

Notes:
- No prompt changes are required. This runs post-generation.
- For SHAP and attention, pass the same `model` and `tokenizer` used to generate.
"""

from typing import Dict, Any, List
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Optional deps
try:
    import shap
except Exception:
    shap = None
try:
    from lime.lime_text import LimeTextExplainer
except Exception:
    LimeTextExplainer = None


# ============== Helpers ==============

def _split_sentences(text: str) -> List[str]:
    parts = [p.strip() for p in text.replace("!", ".").replace("?", ".").split(".")]
    return [p for p in parts if p]

def _semantic_explain(speech: str, question: str) -> str:
    segments = _split_sentences(speech)
    if not segments:
        return "No speech segments to analyze."
    corpus = segments + [question]
    vec = TfidfVectorizer(stop_words="english")
    X = vec.fit_transform(corpus)
    sims = cosine_similarity(X[-1], X[:-1]).ravel()
    if sims.size == 0:
        return "No similarity signal detected."
    order = np.argsort(sims)[::-1][:2]
    tops = [f"‚Ä¢ \"{segments[i]}\" (sim={sims[i]:.2f})" for i in order if sims[i] > 0]
    return "Likely influential speech parts:\n" + ("\n".join(tops) if tops else "No strong matches.")

def _avg_logprob_for_target(model, tokenizer, context: str, target: str, device: str = None) -> float:
    import torch
    model_device = next(model.parameters()).device
    if device is None:
        device = str(model_device)
    with torch.no_grad():
        full = context + ("\n" if context and not context.endswith("\n") else "") + target
        enc = tokenizer(full, return_tensors="pt")
        input_ids = enc["input_ids"].to(model_device)
        attn_mask = enc.get("attention_mask", None)
        if attn_mask is not None:
            attn_mask = attn_mask.to(model_device)
        ctx_ids = tokenizer(context, return_tensors="pt")["input_ids"].to(model_device)
        ctx_len = ctx_ids.shape[1]
        full_len = input_ids.shape[1]
        tgt_len = full_len - ctx_len
        if tgt_len <= 0:
            return float("-inf")
        outputs = model(input_ids=input_ids, attention_mask=attn_mask)
        logits = outputs.logits
        logprobs = logits.log_softmax(dim=-1)
        target_token_ids = input_ids[:, ctx_len:full_len]
        prev_positions = logprobs[:, ctx_len-1:full_len-1, :] if ctx_len > 0 else logprobs[:, :full_len-1, :]
        seq = min(prev_positions.shape[1], target_token_ids.shape[1])
        prev_positions = prev_positions[:, -seq:, :]
        target_token_ids = target_token_ids[:, :seq]
        tok_logprobs = prev_positions.gather(dim=-1, index=target_token_ids.unsqueeze(-1)).squeeze(-1)
        return float(tok_logprobs.mean().item())



def _shap_explain(speech: str, question: str, model, tokenizer) -> str:
    if shap is None:
        return "SHAP not available. Install `shap`."
    try:
        masker = shap.maskers.Text(tokenizer=lambda x: x.split())

        def score_fn(texts: List[str]) -> np.ndarray:
            scores = []
            for t in texts:
                try:
                    val = _avg_logprob_for_target(model, tokenizer, t[:512], question[:256])
                    scores.append(val)
                except Exception:
                    scores.append(-999.0)
            return np.array(scores, dtype=float)

        explainer = shap.Explainer(score_fn, masker)
        shap_values = explainer([speech[:512]])  # limit speech length

        if shap_values.values is None or not hasattr(shap_values, "values"):
            return "SHAP returned no values."

        token_importance = np.abs(shap_values.values).mean(axis=0)[0]
        tokens = shap_values.data[0]
        order = np.argsort(token_importance)[::-1][:8]
        top = [f"{tokens[i]} ({token_importance[i]:.3f})" for i in order if token_importance[i] > 0]
        return "Top influential speech tokens (SHAP): " + (", ".join(top) if top else "no signal")
    except Exception as e:
        return f"SHAP failed ({type(e).__name__}). Fallback:\n" + _semantic_explain(speech, question)




def _attention_explain(text: str, model, tokenizer) -> str:
    import torch

    # Force-enable attention outputs
    if not getattr(model.config, "output_attentions", False):
        model.config.output_attentions = True

    # Encode safely
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    enc = {k: v.to(next(model.parameters()).device) for k, v in enc.items()}

    with torch.no_grad():
#        outputs = model(**enc, output_attentions=True)
        outputs = model.forward(**enc, output_attentions=True)


    if not hasattr(outputs, "attentions") or outputs.attentions is None:
        return "Attention not available: model did not return attention tensors."

    # Take last-layer attention tensor
    attentions = outputs.attentions[-1]
    if attentions is None or len(attentions) == 0:
        return "No attention data returned."

    attn = attentions[0]  # [heads, seq, seq]
    mean_attn = attn.mean(0).mean(0).cpu().numpy()  # mean over heads
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    order = np.argsort(mean_attn)[::-1][:8]
    top = [f"{tokens[i]} ({mean_attn[i]:.3f})" for i in order]
    return "Top attention-weighted tokens: " + ", ".join(top)




def explainability_node(state: Dict[str, Any], model=None, tokenizer=None, mode: str = "semantic") -> Dict[str, Any]:
    speech = state.get("speech", "") or ""
    question = state.get("generated_question", "") or ""

    if not speech and not question:
        state["explanation"] = "Insufficient data for explainability."
        return state

    try:
        # --- Semantic mode ---
        if mode == "semantic":
            state["explanation"] = _semantic_explain(speech, question)

        # --- SHAP mode ---
        elif mode == "shap":
            if model is None or tokenizer is None:
                state["explanation"] = "SHAP requires model/tokenizer."
            else:
                state["explanation"] = _shap_explain(speech, question, model, tokenizer)

        # --- Attention mode ---
        elif mode == "attention":
            if model is None or tokenizer is None:
                state["explanation"] = "Attention requires model/tokenizer."
            else:
                combo = (speech + "\n\nQuestion: " + question).strip()
                state["explanation"] = _attention_explain(combo, model, tokenizer)

        # --- LIME mode ---
        elif mode == "lime":
            if LimeTextExplainer is None:
                state["explanation"] = "LIME not available (package missing). Install `lime`."
                return state

            explainer = LimeTextExplainer()
            try:
                # LIME needs a predict_proba function; we use a mock one for now
                exp = explainer.explain_instance(
                    speech,
                    lambda texts: np.array([[np.random.rand(), np.random.rand()] for _ in texts]),
                    num_features=6
                )
                state["explanation"] = "LIME placeholder output: " + str(exp.as_list()[:3])
            except Exception as e:
                state["explanation"] = f"LIME failed ({type(e).__name__})."

        # --- Unknown mode ---
        else:
            state["explanation"] = "Unknown explainability mode."

    except Exception as e:
        state["explanation"] = f"Explainability error: {type(e).__name__}: {e}"

    return state


In [3]:
# --- 3Ô∏è‚É£ Authenticate ngrok ---
NGROK_AUTH_TOKEN = "34ktbqghT9LZGQRy9kuxtWT3P29_6iusbqrquTMMu12cxwwVQ"   # Replace with your token
!ngrok config add-authtoken {NGROK_AUTH_TOKEN}

# --- 4Ô∏è‚É£ Initialize Flask app ---
app = Flask(__name__)
CORS(app)

# --- 5Ô∏è‚É£ Load model once (GPU recommended) ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
    output_attentions=True
)

# --- 6Ô∏è‚É£ Define endpoints ---
@app.route('/explain', methods=['POST'])
def explain():
    data = request.get_json(force=True)
    speech = data.get("speech", "")
    question = data.get("question", "")
    mode = data.get("mode", "shap")
    print("="*80)
    print(f"üß© Explainability request (mode={mode})")
    state = {"speech": speech, "generated_question": question}
    explained = explainability_node(state, model=model, tokenizer=tokenizer, mode=mode)
    print("‚úÖ Explanation:", explained["explanation"])
    return jsonify({"explanation": explained["explanation"]})

@app.route('/generate', methods=['POST'])
def generate():
    data = request.get_json(force=True)
    prompt = data.get("prompt", "")
    print("="*80)
    print("üì© Received JSON payload:", data)
    messages = [
        {"role": "system", "content": "You are an investigative journalist conducting a live interview."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=128, temperature=0.8)
    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("‚úÖ Model output:", response_text)
    return jsonify({"response": response_text.strip()})

@app.route('/ping', methods=['GET'])
def ping():
    return jsonify({"status": "alive"})

# --- 8Ô∏è‚É£ Create public URL & run ---
public_url = ngrok.connect(5000)
print("‚úÖ Public URL:", public_url.public_url)
print("Use this URL in your local Flask app.")

app.run(port=5000)


Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml                                


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
2025-11-05 19:58:58.046147: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory f

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

‚úÖ Public URL: https://unbevelled-articularly-linn.ngrok-free.dev
Use this URL in your local Flask app.
 * Serving Flask app '__main__'
 * Debug mode: off


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


üì© Received JSON payload: {'prompt': '### ROLE PROMPTING\n    You are **Investigative Hawk**, a professional journalist.\n    Your characterization: A relentless fact-checker who probes claims, demands concrete evidence, and exposes inconsistencies without hesitation.\n    \n    ### CONTEXT\n    Press conference topic: **AI in healthcare**\n    Guest type: **CEO**\n    Opening statement:\n    """hello we will launch our new ai model for cancer classification , this will detect the desease before two months of it \n"""\n    \n    Ongoing dialogue transcript:\n    \n\n\n    \n    ### CONSTRAINTS\n    - Ask **only one** concise, sharp, and contextually relevant question.\n    - Stay aligned with your persona‚Äôs tone.\n    - Do **not** answer on behalf of the guest.\n    - Avoid repetition or generic questions.\n    - Base your question strictly on the dialogue and factual context above.\n    \n    ### THINKING (internal, invisible)\n    Before producing your question, briefly reason ab

`sdpa` attention does not support `output_attentions=True` or `head_mask`. Please set your attention to `eager` if you want any of these features.


‚úÖ Model output: You are an investigative journalist conducting a live interview.

### ROLE PROMPTING
    You are **Investigative Hawk**, a professional journalist.
    Your characterization: A relentless fact-checker who probes claims, demands concrete evidence, and exposes inconsistencies without hesitation.
    
    ### CONTEXT
    Press conference topic: **AI in healthcare**
    Guest type: **CEO**
    Opening statement:
    """hello we will launch our new ai model for cancer classification , this will detect the desease before two months of it 
"""
    
    Ongoing dialogue transcript:
    


    
    ### CONSTRAINTS
    - Ask **only one** concise, sharp, and contextually relevant question.
    - Stay aligned with your persona‚Äôs tone.
    - Do **not** answer on behalf of the guest.
    - Avoid repetition or generic questions.
    - Base your question strictly on the dialogue and factual context above.
    
    ### THINKING (internal, invisible)
    Before producing your questio