In [6]:
!pip install shap transformers torch langchain langchain-huggingface langchain-groq lime



In [7]:

from langchain_groq import ChatGroq
import getpass
import os

from google.colab import userdata

os.environ["GROQ_API_KEY"] = userdata.get('key1')



In [8]:
#llama3-8b-8192
#mixtral-8x7b-32768
#gemma-7b-it
#llama3-70b-8192
llm = ChatGroq(
    model="llama3-70b-8192",
    temperature=0,
    max_tokens=500,  # Limit the number of tokens to avoid explanations
    timeout=None,
    max_retries=2)

llm

ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x7d189e47ad90>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x7d189e332d90>, model_name='llama3-70b-8192', temperature=1e-08, model_kwargs={}, groq_api_key=SecretStr('**********'), max_tokens=500)

#shap

In [9]:
import shap
import numpy as np
import pandas as pd
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


input_text = """[Patient:
Doctor, my mood has been all over the place. Sometimes, I feel extremely energetic, like I can do anything, and I don’t even need much sleep. My mind races with so many ideas, and I feel unstoppable. But then, out of nowhere, I crash.

During these lows, I feel completely drained, hopeless, and lose interest in everything. Even getting out of bed feels like a struggle. My sleep patterns are all over the place—sometimes I barely sleep, and other times I sleep too much.

I’ve also been making impulsive decisions when I’m feeling high, like spending too much money or talking too fast without thinking. But when the low hits, I regret everything and feel like I’m stuck in a deep hole. I don’t understand why this keeps happening. Could there be something wrong with me?
]"""

# Define the LLM prompt for disease prediction
prompt = PromptTemplate.from_template(
    """
    Predict the disease name from the statement: {question}. Provide only the main disease in this format: **disease name**.

    Guidelines:
    - If no disease is mentioned, return **"No disease detected"**.
    - For multiple diseases, choose the most relevant or primary one.
    - Ignore non-disease information such as symptoms or conditions without a direct disease name.
    - The response should be strictly between **1 to 5 words**, with no additional information.
    """
)


In [10]:
# Initialize LLM chain
chain = LLMChain(llm=llm, prompt=prompt)

# Function to get disease prediction from LLM
def mistral_predict(text):
    response = chain.invoke({"question": text})
    return response.get("text", "").strip()

# Function to extract disease name from LLM response
def extract_disease(output_text):
    output_text = output_text.strip("**")  # Remove potential markdown formatting
    if output_text.lower() == "no disease detected":
        return "No Disease"
    return output_text



  chain = LLMChain(llm=llm, prompt=prompt)


In [11]:
# Function to generate probability outputs for SHAP
def mistral_predict_proba_shap(texts):
    predictions = []
    for text in texts:
        predicted_disease = extract_disease(mistral_predict(text))

        # Assign probabilities
        if predicted_disease == "No Disease":
            prob = [0.5, 0.5]  # Neutral probability if no disease detected
        else:
            prob = [0.8, 0.2]  # Higher confidence if a disease is detected

        predictions.append(prob[0])  # SHAP expects single output probability

    return np.array(predictions)


In [12]:
import shap
import numpy as np

# Define SHAP explainer with a Text masker
text_masker = shap.maskers.Text()
explainer = shap.Explainer(mistral_predict_proba_shap, text_masker)

# Compute SHAP values
sample_texts = [input_text]   # Select 5 texts
shap_values = explainer(sample_texts)

# Visualize SHAP explanations
shap.text_plot(shap_values)  # Use text_plot instead of summary_plot


  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [18:41, 1121.11s/it]             


In [13]:
# sample_texts = input_texts[:5]  # First 5 samples

for text in sample_texts:
    prediction = mistral_predict(text)
    print(f"🔹 Input: {text}\n🔸 LLM Prediction: {prediction}\n")


🔹 Input: [Patient:
Doctor, my mood has been all over the place. Sometimes, I feel extremely energetic, like I can do anything, and I don’t even need much sleep. My mind races with so many ideas, and I feel unstoppable. But then, out of nowhere, I crash.

During these lows, I feel completely drained, hopeless, and lose interest in everything. Even getting out of bed feels like a struggle. My sleep patterns are all over the place—sometimes I barely sleep, and other times I sleep too much.

I’ve also been making impulsive decisions when I’m feeling high, like spending too much money or talking too fast without thinking. But when the low hits, I regret everything and feel like I’m stuck in a deep hole. I don’t understand why this keeps happening. Could there be something wrong with me?
]
🔸 LLM Prediction: **Bipolar Disorder**



In [14]:
# Print SHAP values
for i in range(len(sample_texts)):
    print(f"🔹 Input: {sample_texts[i]}")
    print(f"🔸 SHAP Values: {shap_values[i].values}")  # Check values per token
    print(f"🔹 Feature Importance: {shap_values[i].data}")  # Check tokens


🔹 Input: [Patient:
Doctor, my mood has been all over the place. Sometimes, I feel extremely energetic, like I can do anything, and I don’t even need much sleep. My mind races with so many ideas, and I feel unstoppable. But then, out of nowhere, I crash.

During these lows, I feel completely drained, hopeless, and lose interest in everything. Even getting out of bed feels like a struggle. My sleep patterns are all over the place—sometimes I barely sleep, and other times I sleep too much.

I’ve also been making impulsive decisions when I’m feeling high, like spending too much money or talking too fast without thinking. But when the low hits, I regret everything and feel like I’m stuck in a deep hole. I don’t understand why this keeps happening. Could there be something wrong with me?
]
🔸 SHAP Values: [ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.        