# Imports

In [1]:
!pip install protobuf==3.20.3
!pip install transformers-interpret
!CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121

Collecting protobuf==3.20.3
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 6.33.0
    Uninstalling protobuf-6.33.0:
      Successfully uninstalled protobuf-6.33.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
onnx 1.18.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
a2a-sdk 0.3.10 requi

In [2]:
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoConfig
import transformers
import shap
from transformers_interpret import SequenceClassificationExplainer
from llama_cpp import Llama
from huggingface_hub import snapshot_download
import pandas as pd

2025-11-27 01:29:06.565618: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764206946.749254      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764206946.801925      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Model configurations
MODEL_CONFIGS = {
    "BERT": "/kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/bert-base-cased-finetuned/checkpoint-5530",
    "RoBERTa": "/kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/roberta-base-finetuned/checkpoint-6320", 
    "MentalBERT": "/kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/deberta-v3-base-finetuned/checkpoint-4345",
}

LABEL_NAMES = ["Normal", "Depressed"] 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Global variables
models = {}
tokenizers = {}
pipelines = {}
llm = None

def load_qwen_model():
    """Load Qwen model for explanation generation"""
    global llm
    try:
        print("Downloading Qwen model...")
        repo_id = "Qwen/Qwen2.5-1.5B-Instruct-GGUF" # Using a smaller, faster model for demo
        model_file = "qwen2.5-1.5b-instruct-q4_k_m.gguf"
        
        model_dir = snapshot_download(repo_id)
        
        print("Loading Qwen GGUF model...")
        llm = Llama(
            model_path=f"{model_dir}/{model_file}",
            n_ctx=2048,
            n_gpu_layers=-1, # Offload all to GPU
            verbose=False
        )
        print("Qwen model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return False

def load_classification_models():
    """Load all classification models and tokenizers"""
    for name, path in MODEL_CONFIGS.items():
        try:
            print(f"Loading {name} from {path}...")
            tokenizers[name] = AutoTokenizer.from_pretrained(path)
            
            # FIX: Ensure output_attentions is True in config
            config = AutoConfig.from_pretrained(path)
            config.output_attentions = True
            config.output_hidden_states = True
            config.num_labels = 2 
            
            # Load model with config
            models[name] = AutoModelForSequenceClassification.from_pretrained(
                path, 
                config=config
            ).to(device)
            models[name].eval()
            
            # Create pipeline
            pipelines[name] = pipeline(
                'text-classification', 
                model=models[name], 
                tokenizer=tokenizers[name], 
                device=device,
                top_k=None 
            )
            print(f"✓ {name} loaded successfully")
        except Exception as e:
            print(f"✗ Error loading {name}: {e}")

def initialize_models():
    print("="*50)
    print("Initializing Models...")
    load_classification_models()
    qwen_loaded = load_qwen_model()
    print("="*50)

# --- Visualization Functions ---

def get_attention_visualization(text, model_name):
    """
    Calculates attention statistics per word and generates visualizations.
    Returns: DataFrame, Figure (Total Attention), Figure (Average Attention)
    """
    try:
        model = models[model_name]
        tokenizer = tokenizers[model_name]
        
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        
        # Get attention from last layer: [Batch, Heads, Seq_Len, Seq_Len]
        # Average over heads -> [Seq_Len, Seq_Len]
        attention_matrix = outputs.attentions[-1][0].mean(dim=0)
        
        # Calculate 'Attention Received' by summing columns (how much others attended to this token)
        # Alternatively, summing rows gives 'Attention Paid' by this token.
        # 'Received' is often more indicative of importance in classification.
        attention_scores = attention_matrix.sum(dim=0).cpu().numpy()
        
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        
        # Collect raw data
        data = []
        special_tokens = tokenizer.all_special_tokens
        
        for token, score in zip(tokens, attention_scores):
            # Clean up token text (handle BERT ## and RoBERTa Ġ)
            word = token.replace("##", "").replace("Ġ", "").strip()
            
            # Skip special tokens and empty strings
            if token in special_tokens or not word:
                continue
                
            data.append({"Word": word, "Score": score})
            
        df_raw = pd.DataFrame(data)
        
        if df_raw.empty:
            return pd.DataFrame(), None, None

        # Aggregate by Word
        agg_df = df_raw.groupby("Word").agg(
            Total_Attention=("Score", "sum"),
            Frequency=("Score", "count"),
            Average_Attention=("Score", "mean")
        ).reset_index()
        
        # Sort for display
        agg_df = agg_df.sort_values(by="Total_Attention", ascending=False)
        
        # Plot 1: Total Attention
        fig_total, ax1 = plt.subplots(figsize=(10, 6))
        top_total = agg_df.head(15)
        sns.barplot(data=top_total, x="Total_Attention", y="Word", ax=ax1, palette="viridis")
        ax1.set_title(f"Total Attention Mass by Word ({model_name})")
        plt.tight_layout()
        
        # Plot 2: Average Attention
        fig_avg, ax2 = plt.subplots(figsize=(10, 6))
        top_avg = agg_df.sort_values(by="Average_Attention", ascending=False).head(15)
        sns.barplot(data=top_avg, x="Average_Attention", y="Word", ax=ax2, palette="magma")
        ax2.set_title(f"Average Attention per Occurrence ({model_name})")
        plt.tight_layout()

        return agg_df, fig_total, fig_avg

    except Exception as e:
        print(f"Attn Error: {e}")
        return pd.DataFrame(), None, None

def get_integrated_gradients(text, model_name):
    """Generates a bar chart of word attributions."""
    try:
        model = models[model_name]
        tokenizer = tokenizers[model_name]
        
        explainer = SequenceClassificationExplainer(model, tokenizer)
        word_attributions = explainer(text) # Returns list of (word, score)
        
        # Filter special tokens for cleaner plot
        filtered = [x for x in word_attributions if x[0] not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>']]
        # Sort by absolute impact
        filtered.sort(key=lambda x: abs(x[1]), reverse=True)
        top_features = filtered[:10] # Top 10
        
        words = [x[0] for x in top_features]
        scores = [x[1] for x in top_features]
        colors = ['green' if s > 0 else 'red' for s in scores]
        
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(words, scores, color=colors)
        ax.set_xlabel("Attribution Score")
        ax.set_title(f"Top Words Impacting Prediction ({model_name})")
        return fig, top_features
    except Exception as e:
        print(f"IG Error: {e}")
        return None, []

def get_shap_values(text, model_name):
    """Generates a bar chart using SHAP values."""
    try:
        pipe = pipelines[model_name]
        # Use a generic explainer or partition explainer for text
        explainer = shap.Explainer(pipe)
        shap_values = explainer([text])
        
        # Extract values for the first sample
        values = shap_values[0].values
        if len(values.shape) > 1: # If multi-class, take the predicted class or positive class
            values = values[:, 1] # Assuming index 1 is 'Depressed'
            
        data = values
        feature_names = shap_values[0].data
        
        # Create DataFrame for easy plotting
        df = pd.DataFrame({'word': feature_names, 'shap': data})
        df['abs_shap'] = df['shap'].abs()
        df = df.sort_values('abs_shap', ascending=False).head(10)
        
        fig, ax = plt.subplots(figsize=(10, 6))
        colors = ['blue' if x > 0 else 'orange' for x in df['shap']]
        ax.barh(df['word'], df['shap'], color=colors)
        ax.set_title(f"SHAP Feature Importance ({model_name})")
        return fig
    except Exception as e:
        print(f"SHAP Error: {e}")
        return None

def generate_explanation_prompt(text, results, top_tokens):
    """Prepare prompt for Qwen"""
    
    # Format predictions
    preds_str = ""
    for model, res in results.items():
        if 'error' not in res:
            preds_str += f"- {model}: Predicted {res['predicted_class']} (Confidence: {max(res['probabilities'].values()):.2%})\n"
            
    # Format key words
    tokens_str = ", ".join([f"{t[0]} ({t[1]:.2f})" for t in top_tokens[:5]])
    
    prompt = f"""<|im_start|>system
You are an empathetic AI psychological assistant. Analyze the following classification results.
<|im_end|>
<|im_start|>user
Input Text: "{text}"

Model Predictions:
{preds_str}

Key Influential Words (Integrated Gradients): {tokens_str}

Please provide:
1. A summary of the mental state suggested by the text.
2. Why the models likely predicted this (based on the key words).
3. A friendly disclaimer that this is an AI tool and not a doctor.
<|im_end|>
<|im_start|>assistant
"""
    return prompt

def call_qwen_model(prompt):
    if llm is None: 
        return "Error: Qwen model not loaded."
    
    output = llm(
        prompt, 
        max_tokens=512, 
        stop=["<|im_end|>"], 
        echo=False
    )
    return output['choices'][0]['text']

In [5]:
initialize_models()

with gr.Blocks(title="Mental Health Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Mental Health Text Analysis & Interpretation")
    
    # State storage
    stored_text = gr.State()
    stored_results = gr.State()
    
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Input Text", lines=5, placeholder="I've been feeling...")
            model_selector = gr.CheckboxGroup(choices=list(MODEL_CONFIGS.keys()), value=list(MODEL_CONFIGS.keys())[:1], label="Select Models")
            analyze_btn = gr.Button("Step 1: Analyze", variant="primary")
        
        with gr.Column(scale=1):
            results_output = gr.Markdown(label="Predictions")
    
    # Visualization Section
    interpret_btn = gr.Button("Step 2: Generate Visualizations", variant="secondary", visible=False)
    
    with gr.Tabs(visible=False) as viz_tabs:
        # Updated Attention Tab with Sub-tabs
        with gr.Tab("Attention Analysis"):
            with gr.Tabs():
                with gr.Tab("Total Attention View"):
                    attn_plot_total = gr.Plot(label="Total Attention by Word")
                with gr.Tab("Average Attention View"):
                    attn_plot_avg = gr.Plot(label="Average Attention by Word")
                with gr.Tab("Data Table"):
                    attn_table = gr.Dataframe(label="Attention Stats")
        
        with gr.Tab("Integrated Gradients"):
            ig_plot = gr.Plot()
        with gr.Tab("SHAP Values"):
            shap_plot = gr.Plot()
            
    # Explanation Section
    explain_btn = gr.Button("Step 3: AI Explanation", variant="secondary", visible=False)
    explanation_output = gr.Markdown(visible=False)

    # --- Callbacks ---

    def step1_analyze(text, selected_models):
        if not text or not selected_models:
            return "Please enter text and select a model.", gr.update(visible=False), text, None
        
        results = {}
        output_md = "### Model Predictions\n"
        
        for name in selected_models:
            if name not in pipelines: continue
            
            try:
                res = pipelines[name](text, return_all_scores=True)
                probs = {item['label']: item['score'] for item in res[0]}
                pred_label = max(probs, key=probs.get)
                
                results[name] = {
                    "predicted_class": pred_label,
                    "probabilities": probs
                }
                
                output_md += f"**{name}**: `{pred_label}`\n"
                for l, p in probs.items():
                    output_md += f"- {l}: {p:.1%}\n"
                output_md += "\n"
            except Exception as e:
                results[name] = {"error": str(e)}
                output_md += f"**{name}**: Error - {e}\n"

        return output_md, gr.update(visible=True), text, results

    def step2_visualize(text, results):
        if not results: return [None]*6
        
        model_name = list(results.keys())[0]
        
        # Get Attention Data and Plots
        attn_df, attn_fig_total, attn_fig_avg = get_attention_visualization(text, model_name)
        
        # Get other visualizations
        ig_fig, top_tokens = get_integrated_gradients(text, model_name)
        shap_fig = get_shap_values(text, model_name)
        
        return (
            attn_fig_total, 
            attn_fig_avg, 
            attn_df,
            ig_fig, 
            shap_fig, 
            gr.update(visible=True), # Show tabs
            gr.update(visible=True)  # Show explain btn
        )

    def step3_explain_llm(text, results):
        model_name = list(results.keys())[0]
        _, top_tokens = get_integrated_gradients(text, model_name)
        
        prompt = generate_explanation_prompt(text, results, top_tokens)
        explanation = call_qwen_model(prompt)
        return gr.update(value=explanation, visible=True)

    # Wiring
    analyze_btn.click(
        step1_analyze, 
        inputs=[text_input, model_selector], 
        outputs=[results_output, interpret_btn, stored_text, stored_results]
    )
    
    interpret_btn.click(
        step2_visualize,
        inputs=[stored_text, stored_results],
        outputs=[attn_plot_total, attn_plot_avg, attn_table, ig_plot, shap_plot, viz_tabs, explain_btn]
    )
    
    explain_btn.click(
        step3_explain_llm,
        inputs=[stored_text, stored_results],
        outputs=[explanation_output]
    )

Initializing Models...
Loading BERT from /kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/bert-base-cased-finetuned/checkpoint-5530...


Device set to use cuda


✓ BERT loaded successfully
Loading RoBERTa from /kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/roberta-base-finetuned/checkpoint-6320...


Device set to use cuda


✓ RoBERTa loaded successfully
Loading MentalBERT from /kaggle/input/disorbert-finetuned-models/pytorch/default/1/finetuned-models/deberta-v3-base-finetuned/checkpoint-4345...


Device set to use cuda


✓ MentalBERT loaded successfully
Downloading Qwen model...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

qwen2.5-1.5b-instruct-fp16.gguf:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q3_k_m.gguf:   0%|          | 0.00/924M [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q4_0.gguf:   0%|          | 0.00/1.07G [00:00<?, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

LICENSE: 0.00B [00:00, ?B/s]

qwen2.5-1.5b-instruct-q4_k_m.gguf:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q2_k.gguf:   0%|          | 0.00/753M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

qwen2.5-1.5b-instruct-q5_0.gguf:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q5_k_m.gguf:   0%|          | 0.00/1.29G [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q6_k.gguf:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

qwen2.5-1.5b-instruct-q8_0.gguf:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

Loading Qwen GGUF model...


llama_context: n_ctx_per_seq (2048) < n_ctx_train (32768) -- the full capacity of the model will not be utilized


Qwen model loaded successfully!


In [6]:
demo.launch(share=True, debug=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://e465a0a5bdb5ea23c5.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


`return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:13, 13.73s/it]               


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://e465a0a5bdb5ea23c5.gradio.live


