## **Import Libraries**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline, TextGenerationPipeline
import accelerate
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
import json
import dice_ml
from dice_ml import Dice
from dice_ml.utils import helpers
import pandas as pd
from typing import List

import shap
import xgboost
import matplotlib.pyplot as plt

from haystack import component, Document, Pipeline, tracing
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever, InMemoryBM25Retriever
from haystack.components.builders import ChatPromptBuilder, PromptBuilder
from haystack.components.joiners import BranchJoiner
from haystack.components.routers import ConditionalRouter
from haystack.components.converters import OutputAdapter
from haystack.components.writers import DocumentWriter
from haystack.document_stores.types import DuplicatePolicy
from haystack.document_stores.in_memory import InMemoryDocumentStore

from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage

import logging
from haystack import tracing
from haystack.tracing.logging_tracer import LoggingTracer

import gradio as gr
import time

Warning: Installing farm-haystack and haystack-ai in the same Python environment (virtualenv, Colab, or system) causes problems.
Installing both packages in the same environment can somehow work or fail in obscure ways. We suggest installing only one of these packages per Python environment. Make sure that you remove both packages if they are installed in the same environment, followed by installing only one of them:

pip uninstall -y farm-haystack haystack-ai

pip install haystack-ai

//pip install git+https://github.com/deepset-ai/haystack.git@main 

## **Data Preparation**

In [2]:
# 1. Data Preparation
# Loads the sklearn breast-cancer data, makes an 80/20 train/test split and
# standardizes the features. Later cells rely on the globals defined here
# (X, y, feature_names, X_train, X_test, y_train, y_test, scaler).

# Load dataset
# NOTE: class ids follow `data.target_names` (sklearn documents 0 = malignant,
# 1 = benign) — keep this in mind when turning predictions into labels.
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names

# Split into train and test (fixed random_state keeps the split reproducible)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize features: the scaler is fit on the training split only and then
# applied to the test split. X_train / X_test are overwritten with standardized
# values; use scaler.inverse_transform(...) to get back to original feature units.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Create PyTorch Datasets
class CancerDataset(Dataset):
    """Map-style dataset pairing feature rows with binary labels.

    Both inputs are converted once, up front, to float32 tensors so that
    ``__getitem__`` is a cheap tensor index.

    Args:
        X: 2-D array-like of features (one row per sample).
        y: 1-D array-like of labels, stored as float32 for BCELoss.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.float32)
        self.X, self.y = features, labels

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Wrap the standardized arrays as Datasets. The training loader reshuffles each
# epoch; the test loader keeps a fixed order so predictions line up with X_test.
train_dataset = CancerDataset(X_train, y_train)
test_dataset = CancerDataset(X_test, y_test)

# Create DataLoaders (mini-batches of 32)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


## **Define and Train the DNN Model**

In [None]:
# 2. Define and Train the DNN Model

class SimpleDNN(nn.Module):
    """Small feed-forward binary classifier: input -> 32 -> 16 -> 1 with ReLU.

    The final Sigmoid means ``forward`` returns a probability per row,
    shape (batch, 1), suitable for ``nn.BCELoss``.

    Args:
        input_dim: number of input features.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 32, 16]
        layers = []
        # Hidden layers: Linear followed by ReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: a single unit squashed to (0, 1).
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

input_dim = X_train.shape[1]

# Seed torch's global RNG before building the model so weight initialisation
# and DataLoader shuffling are identical on every Restart & Run All.
torch.manual_seed(42)
model = SimpleDNN(input_dim)

# Loss and optimizer: the network already ends in a Sigmoid, so plain BCELoss
# (not BCEWithLogitsLoss) is the matching criterion.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 50
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        # squeeze(-1) drops only the trailing output dim: (B, 1) -> (B,).
        # A bare squeeze() would turn a final batch of size 1 into a 0-d tensor
        # whose shape no longer matches y_batch.
        outputs = model(X_batch).squeeze(-1)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    if (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

# Evaluation on the held-out test split (no gradient tracking needed)
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    predictions = []
    for X_batch, y_batch in test_loader:
        outputs = model(X_batch).squeeze(-1)
        # Threshold the class-1 probability at 0.5 to obtain hard predictions
        predicted = (outputs >= 0.5).int()
        predictions.extend(predicted.numpy())
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy*100:.2f}%")


## **Generate Counterfactual Explanations**

In [4]:
# 3. Generate Counterfactual Explanations Using DiCE

# DiCE and SHAP receive DataFrames in the ORIGINAL feature units.
# X_train / X_test were standardized in the data-preparation cell, and both model
# wrappers (PyTorchModelWrapper, ShapModelWrapper) call scaler.transform
# themselves. Passing already-standardized values would scale twice, and the
# network would see inputs far outside its training distribution.
# Raw units also make the counterfactual documents readable for the LLM.

# Test data (features in raw units + target) as a pandas DataFrame
X_test_df = pd.DataFrame(scaler.inverse_transform(X_test), columns=feature_names)
y_test_df = pd.Series(y_test, name='target')
test_df = X_test_df.copy()
test_df['target'] = y_test_df

# Training data for DiCE (it samples counterfactual values from these feature ranges)
X_train_df = pd.DataFrame(scaler.inverse_transform(X_train), columns=feature_names)
y_train_df = pd.Series(y_train, name='target')
train_df = X_train_df.copy()
train_df['target'] = y_train_df

# Ensure feature_names is a list (DiCE's continuous_features expects a list)
if isinstance(feature_names, np.ndarray):
    feature_names = feature_names.tolist()

# Initialize DiCE data object; every feature is declared continuous
d = dice_ml.Data(
    dataframe=train_df,
    continuous_features=feature_names,
    outcome_name='target'
)

# Initialize DiCE model wrapper for PyTorch
class PyTorchModelWrapper(nn.Module):
    """Adapts the trained network to DiCE's PyTorch backend.

    Inputs arrive in the original (unscaled) feature units — as a DataFrame,
    Series, NumPy array or torch tensor. They are standardized with ``scaler``
    before being passed to ``model``.

    Args:
        model: trained ``nn.Module`` expected to return one probability of target
            class 1 per row, shape (n, 1).
        scaler: fitted scaler exposing ``transform`` (the one used for training).
    """

    def __init__(self, model, scaler):
        super(PyTorchModelWrapper, self).__init__()
        self.model = model
        self.scaler = scaler

    def _to_2d_array(self, x):
        """Return ``x`` as a 2-D NumPy array with one row per instance."""
        if isinstance(x, (pd.DataFrame, pd.Series)):
            x = x.values
        elif isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        # A single instance given as a 1-D Series/array becomes a
        # (1, n_features) batch instead of failing inside scaler.transform.
        return np.atleast_2d(x)

    def forward(self, x):
        """Standardize ``x`` and return the model output (class-1 probabilities)."""
        x_scaled = self.scaler.transform(self._to_2d_array(x))
        x_tensor = torch.tensor(x_scaled, dtype=torch.float32)
        return self.model(x_tensor)

    def predict_proba(self, x):
        """Return an (n, 2) array of ``[P(class 0), P(class 1)]`` per input row."""
        # reshape(-1) keeps a 1-D vector even for a single row, where squeeze()
        # would produce a 0-d array.
        probs_class_1 = self.forward(x).detach().numpy().reshape(-1)
        return np.column_stack([1 - probs_class_1, probs_class_1])

# Instantiate the model wrapper: DiCE feeds it unscaled feature rows and the
# wrapper applies `scaler` before calling the trained network.
model_wrapper = PyTorchModelWrapper(model, scaler)

# Initialize DiCE model with correct backend ('PYT' = PyTorch model interface)
m = dice_ml.Model(model=model_wrapper, backend='PYT', model_type='classifier')

# Initialize DiCE explainer (no `method` argument, so DiCE's default
# counterfactual generation method is used)
explainer_DiCE = Dice(d, m)

## **Generate SHAP Explanations**

In [None]:
# 4. Create SHAP explanation, global + local

# Convert the PyTorch model to a SHAP-compatible model
class ShapModelWrapper:
    """Exposes the trained network as a plain ``predict`` callable for SHAP.

    Inputs are in the original (unscaled) feature units and are standardized
    with ``scaler`` before inference.

    Args:
        model: trained ``nn.Module`` expected to return shape (n, 1) probabilities.
        scaler: fitted scaler exposing ``transform``.
    """

    def __init__(self, model, scaler):
        self.model = model
        self.scaler = scaler

    def predict(self, X):
        """Return a 1-D array with the model's class-1 probability for each row of ``X``."""
        if isinstance(X, (pd.DataFrame, pd.Series)):
            X = X.values
        # Treat a single 1-D instance as a batch of one row.
        X_scaled = self.scaler.transform(np.atleast_2d(X))
        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        with torch.no_grad():
            output = self.model(X_tensor).numpy()
        # reshape(-1) instead of squeeze(): a one-row batch must stay shape (1,)
        # rather than collapsing to a 0-d scalar that cannot be indexed per sample.
        return output.reshape(-1)

# Instantiate the SHAP model wrapper
shap_model_wrapper = ShapModelWrapper(model, scaler)

# Explain the whole test split; subsample X_test_df here if memory becomes an issue.
X_shap = X_test_df

# Generate SHAP explainer and values; the training split serves as background data.
explainer_SHAP = shap.Explainer(shap_model_wrapper.predict, X_train_df)
shap_values = explainer_SHAP(X_shap)

# SHAP Summary Plot
plt.figure()
shap.summary_plot(shap_values, X_shap, show=False)
plt.savefig('shap_summary_plot.png', bbox_inches='tight')
plt.show()

# SHAP Feature Importance Plot. shap.plots.bar calls plt.show() by default,
# which would leave savefig with an empty canvas — hence show=False.
plt.figure()
shap.plots.bar(shap_values, max_display=31, show=False)
plt.savefig('shap_feature_importance_plot.png', bbox_inches='tight')
plt.show()

# SHAP Feature Importance Plot + Clustering (hierarchical clustering of the raw features)
clust = shap.utils.hclust(X, y, linkage="single")
plt.figure()
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=0.5, max_display=31, show=False)
plt.savefig('shap_feature_importance_clustering_plot.png', bbox_inches='tight')
plt.show()

# Decision plots. The explained output is a probability (the network ends in a
# Sigmoid), so base value + sum of SHAP values is P(class 1). It must be
# thresholded at 0.5, as in the evaluation loop, not at 0 (the log-odds
# convention), which would mark every case as class 1.
this_y_pred = (shap_values.values.sum(1) + shap_values[0].base_values) >= 0.5
this_misclassified = this_y_pred != y_test_df.to_numpy()
shap.decision_plot(base_value=shap_values[0].base_values, shap_values=shap_values.values[5], feature_names=feature_names, highlight=this_misclassified[5], alpha=0.5)
shap.decision_plot(base_value=shap_values[0].base_values, shap_values=shap_values.values[0:], feature_names=feature_names, highlight=this_misclassified[0:], feature_display_range=slice(None, None, -1), alpha=0.5)
shap.decision_plot(base_value=shap_values[0].base_values, shap_values=shap_values.values, feature_names=feature_names, highlight=this_misclassified, feature_display_range=slice(None, None, -1), alpha=0.5, feature_order="hclust")


# Store SHAP values in a DataFrame (one row per explained test instance)
shap_df = pd.DataFrame(shap_values.values, columns=X_shap.columns)
shap_df['expected_value'] = shap_values.base_values
shap_df['observation_index'] = X_shap.index


print(shap_df)

## **Create Document Store**

In [7]:
# 5. Initialize Retrieval-Augmented Generation (RAG) Components
# Store for the per-case explanation Documents. The RAG retriever is built on
# this same object, so explanations must be written here to be retrievable.
document_store = InMemoryDocumentStore()
# Document-side embedder. Keep the model name identical to the query-side
# SentenceTransformersTextEmbedder so both embeddings live in the same space.
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
# Load the embedding model now rather than on first use.
doc_embedder.warm_up()

In [8]:
# sklearn's load_breast_cancer encodes target 0 = malignant and 1 = benign
# (see `data.target_names`). The model's sigmoid output is P(target == 1), so a
# predicted class of 1 is "Benign" and 0 is "Malignant".
_DIAGNOSIS_LABELS = {0: 'Malignant', 1: 'Benign'}


def _diagnosis_label(predicted_class):
    """Map a predicted class id (0 or 1) to its diagnosis label."""
    return _DIAGNOSIS_LABELS[int(predicted_class)]


def _predict_class(model_wrapper, instance_df):
    """Return the hard class (0/1) predicted for a one-row DataFrame (P(class 1) >= 0.5)."""
    return int(model_wrapper.predict_proba(instance_df)[0, 1] >= 0.5)


def _original_instance_document(index, instance, feature_names, prediction_label):
    """Build the Document listing the feature values of the explained case."""
    header = pd.Series({
        "case_index": index,
        "document content type": "Original instance",
        "scope": "local",
        "prediction": prediction_label
    })
    return Document(
        content=pd.concat([header, instance[feature_names]]).to_json(indent=2),
        meta={
            "case_index": index,
            "type": "original",
            "prediction": prediction_label
        }
    )


def _shap_local_document(index, shap_values, feature_names, prediction_label):
    """Build the Document holding the local SHAP values of the explained case.

    Assumes row ``index`` of ``shap_values`` belongs to the same case as row
    ``index`` of the test DataFrame (SHAP computed on the test split, same order).
    """
    shap_values_instance = pd.Series(shap_values.values[index], index=feature_names)
    header = pd.Series({
        "case_index": index,
        "document content type": "SHAP explanation",
        "scope": "local",
        "prediction": prediction_label,
        "explained model output": "probability of class 1 (Benign)"
    })
    return Document(
        content=pd.concat([header, shap_values_instance]).to_json(indent=2),
        meta={
            "source": "shap_explanation_local",
            "case_index": index,
            "type": "shap",
            "scope": "local",
            "prediction": prediction_label
        }
    )


def _counterfactual_documents(index, instance, query, feature_names, explainer_CF, model_wrapper, prediction_label):
    """Generate DiCE counterfactuals for ``query`` and wrap each one in a Document.

    Each Document reports only the features whose value differs from the
    original case. Returns an empty list when DiCE yields no counterfactual table.
    """
    explanation = explainer_CF.generate_counterfactuals(
        query_instances=query,
        total_CFs=3,
        desired_class='opposite',
        features_to_vary='all'
    )
    cf_df = explanation.cf_examples_list[0].final_cfs_df
    if cf_df is None:
        return []

    cf_docs = []
    for _, cf_row in cf_df.iterrows():
        # Re-score the counterfactual so the document states what the model
        # actually predicts for it.
        cf_instance = pd.DataFrame([cf_row[feature_names].values.tolist()], columns=feature_names)
        cf_label = _diagnosis_label(_predict_class(model_wrapper, cf_instance))

        # Features whose value differs from the original case
        changed_features = [
            feature for feature in feature_names
            if not np.isclose(cf_row[feature], instance[feature])
        ]
        changed_values = pd.DataFrame(
            [cf_row[changed_features].values.tolist()], columns=changed_features
        ).T[0]

        cf_docs.append(Document(
            content=pd.Series({
                "case_index": index,
                "document content type": "Counterfactual explanation",
                "scope": "local",
                "changed features": changed_features,
                "number of changed features": len(changed_features),
                "original prediction": prediction_label,
                "counterfactual prediction": cf_label,
                "feature values changed to": changed_values
            }).to_json(indent=2),
            meta={
                "source": "counterfactual_instance",
                "case_index": index,
                "type": "counterfactual",
                "scope": "local",
                "changed features": changed_features,
                "number of changed features": len(changed_features),
                "original prediction": prediction_label,
                "counterfactual prediction": cf_label
            }
        ))
    return cf_docs


def generate_and_store_explanations(index, test_df, shap_values, feature_names, explainer_CF, model_wrapper, doc_embedder, document_store, exclude_global_explanations=True):
    """Build, embed and store the explanation Documents for one test case.

    Writes, all tagged with meta ``case_index=index``:
      * the original feature values and the model's predicted diagnosis,
      * the local SHAP values for that case,
      * one Document per DiCE counterfactual (changed features and new values).

    Args:
        index: positional row of ``test_df`` (and of ``shap_values``) to explain.
        test_df: test-split DataFrame containing the ``feature_names`` columns.
        shap_values: SHAP Explanation computed on the test split in the same row order.
        feature_names: list of feature column names.
        explainer_CF: DiCE explainer used to generate counterfactuals.
        model_wrapper: object whose ``predict_proba`` returns (n, 2) probabilities.
        doc_embedder: warmed-up Haystack document embedder.
        document_store: Haystack document store to write into (duplicates skipped).
        exclude_global_explanations: kept for backward compatibility. Global SHAP
            summary documents are not generated, so it currently has no effect.
    """
    instance = test_df.iloc[index]
    query = instance[feature_names].to_frame().transpose()  # one-row DataFrame

    prediction_label = _diagnosis_label(_predict_class(model_wrapper, query))

    docs = [
        _original_instance_document(index, instance, feature_names, prediction_label),
        _shap_local_document(index, shap_values, feature_names, prediction_label),
    ]
    docs.extend(_counterfactual_documents(
        index, instance, query, feature_names, explainer_CF, model_wrapper, prediction_label
    ))

    # Embed the documents, then write them; identical documents are skipped.
    docs_with_embeddings = doc_embedder.run(docs)
    document_store.write_documents(docs_with_embeddings["documents"], policy=DuplicatePolicy.SKIP)

# Example usage:
# generate_and_store_explanations(index=1, test_df=test_df, shap_values=shap_values, feature_names=feature_names, explainer_CF=explainer_DiCE, model_wrapper=model_wrapper, doc_embedder=doc_embedder, document_store=document_store)

In [None]:
# Populate the shared document store with the explanations for test case 1.
generate_and_store_explanations(index=1, test_df=test_df, shap_values=shap_values, feature_names=feature_names, explainer_CF=explainer_DiCE, model_wrapper=model_wrapper, doc_embedder=doc_embedder, document_store=document_store)

In [None]:
# Inspect what was stored, restricted to the local SHAP explanation documents.
this_filters = dict(
    operator="AND",
    conditions=[dict(field="source", operator="==", value="shap_explanation_local")],
)

print(document_store.filter_documents(filters=this_filters))

## **Create RAG Pipeline**

In [11]:
# Enable logging
# Optional debugging aid: uncomment the lines below to log Haystack component
# activity and trace component inputs/outputs via LoggingTracer.
# NOTE(review): content tracing may write full prompts and documents to the log.
# logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
# logging.getLogger("haystack").setLevel(logging.DEBUG)

# tracing.tracer.is_content_tracing_enabled = True
# tracing.enable_tracing(LoggingTracer(tags_color_strings={"haystack.component.input": "\x1b[1;31m", "haystack.component.name": "\x1b[1;34m"}))

In [None]:
# Load LLM
# Local Hugging Face text generator (not part of either pipeline below; it is run
# separately on the rendered prompt). device 0 assumes a GPU at index 0;
# generation is capped at 1024 new tokens / 120 seconds.
# generator = HuggingFaceLocalChatGenerator(
generator = HuggingFaceLocalGenerator(
    huggingface_pipeline_kwargs={"device":0, "torch_dtype":torch.bfloat16},
    generation_kwargs={"max_new_tokens": 1024, "max_time":120},
    model="meta-llama/Llama-3.2-3B-Instruct",
    # model="meta-llama/Llama-3.2-3B"
    # model="meta-llama/Llama-2-7b-chat-hf"
    # model="HuggingFaceH4/zephyr-7b-beta"
    )
generator.warm_up()

# Create memory store for chat messages. Both pipelines write into this one store,
# so user turns (main pipeline) and assistant turns (memory-writer pipeline)
# end up in a single conversation history.
memory_store = InMemoryChatMessageStore()

rag_pipe_main = Pipeline()

## Main RAG
# components for RAG: query embedding -> top-5 document retrieval -> prompt rendering.
# The retriever is bound to `document_store`; explanations must be written there.
rag_pipe_main.add_component("text_embedder_main", SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"))
rag_pipe_main.add_component("retriever_main", InMemoryEmbeddingRetriever(document_store, top_k=5))
rag_pipe_main.add_component("prompt_builder_main", ChatPromptBuilder(variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]))

# components for memory: read recent history into the prompt, and store the user turn
rag_pipe_main.add_component("memory_retriever_main", ChatMessageRetriever(memory_store))
rag_pipe_main.add_component("memory_joiner_main", BranchJoiner(List[ChatMessage]))
rag_pipe_main.add_component("memory_writer_main", ChatMessageWriter(memory_store))


# connections for RAG
rag_pipe_main.connect("text_embedder_main", "retriever_main")
rag_pipe_main.connect("retriever_main.documents", "prompt_builder_main.documents")

# connections for memory
rag_pipe_main.connect("memory_joiner_main", "memory_writer_main")
rag_pipe_main.connect("memory_retriever_main", "prompt_builder_main.memories")

## Chat memory module
# Separate small pipeline used after generation to store the assistant's reply.
rag_pipe_memory_writer = Pipeline()

# components for memory
rag_pipe_memory_writer.add_component("memory_joiner", BranchJoiner(List[ChatMessage]))
rag_pipe_memory_writer.add_component("memory_writer", ChatMessageWriter(memory_store))

# connections for memory
rag_pipe_memory_writer.connect("memory_joiner", "memory_writer")


In [None]:
# Visualize the main RAG pipeline graph (components and their connections).
rag_pipe_main.show()

In [None]:
# Visualize the memory-writer pipeline graph.
rag_pipe_memory_writer.show()

In [15]:
system_message = ChatMessage.from_system("""You are a helpful AI assistant using provided supporting documents and conversation history to assist humans.""")

# User-turn prompt: a Jinja template rendered by ChatPromptBuilder with `query`,
# `documents` and `memories`. It embeds raw Llama-3 chat tokens
# (<|begin_of_text|>, <|start_header_id|>, <|eot_id|>), so the rendered text is
# meant to be passed to the generator as a plain string.
user_message_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>Given the conversation history and the provided supporting documents, fulfill the task or give an answer to the question.
Note that supporting documents are not part of the conversation. If questions can't be answered or tasks cannot be fulfilled by using supporting documents, state this in your response.

The following are criteria for counterfactual explanations:
Validity: The counterfactual must accurately reflect changes that would lead to a different diagnostic prediction by the model.
Proximity: Changes should be minimal, keeping the counterfactual instance close to the original case.
Sparsity: The explanation should involve altering as few features as possible.
Feasibility: Proposed changes must be medically plausible, even if not actionable.

SHAP explanations provide information regarding the contribution of each feature to the prediction.

Supporting documents contain information regarding the conducted analysis of human breast tissue with a suspicion of breast cancer along with a prediction from a specialized machine learning model.
Evaluate all counterfactual explanations yielded by the machine learning model within the provided documents using criteria for counterfactual explanations focusing on the changed features and feature values.
By default, focus your response on analysis regarding feature impact using SHAP explanations and connect your observations to your evaluation of counterfactual explanations.

    Conversation history:
    {% for memory in memories %}
        {{ memory.content }}
    {% endfor %}

    Supporting documents:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}
    
    Documents with the document content type "Original instance" contain the original feature values for each feature variable.
    Documents with the document content type "Counterfactual explanation" contain only the features changed in the counterfactual and their new values.
    Documents with the document content type "SHAP explanation" contain SHAP values for each feature variable instead of feature values. These SHAP values show the impact of each feature and its value on the model prediction for the original instance.

    \\Task or Question: {{query}}
    \\Response:<|eot_id|>
  <|start_header_id|>assistant<|end_header_id|>
"""

user_message = ChatMessage.from_user(user_message_template)
# Template list handed to ChatPromptBuilder at run time: [system, user].
this_chat_template = [system_message, user_message]

In [19]:
def run_demo(this_index):
    """Index the explanations for one test case and launch a Gradio chat over them.

    The RAG retriever (``retriever_main``) was constructed with the module-level
    ``document_store``, so the explanations are written to that same store. A
    separate local store would never be queried. Documents already stored for
    this case are removed first, so repeated runs do not accumulate several
    counterfactual sets for the same ``case_index``.

    Args:
        this_index: positional row of ``test_df`` whose explanations are explored.
    """
    doc_filter = {
        "operator": "AND",
        "conditions": [
            {"field": "case_index", "operator": "==", "value": this_index}  # only retrieve Documents for the relevant data instance
        ],
    }

    # Start from a clean slate for this case in the store the retriever reads from.
    stale_ids = [doc.id for doc in document_store.filter_documents(filters=doc_filter)]
    if stale_ids:
        document_store.delete_documents(stale_ids)
    generate_and_store_explanations(index=this_index, test_df=test_df, shap_values=shap_values, feature_names=feature_names, explainer_CF=explainer_DiCE, model_wrapper=model_wrapper, doc_embedder=doc_embedder, document_store=document_store)

    # NOTE(review): `memory_store` is shared and not cleared here, so chat history
    # from earlier demo sessions carries over into this one.

    def this_chatbot(message, history):
        """Answer one chat turn: retrieve case documents and memory, generate, store the reply.

        ``history`` is supplied by Gradio but unused; conversation context comes
        from the chat message store instead.
        """
        res = rag_pipe_main.run(
            data={
                "retriever_main": {"filters": doc_filter},
                "text_embedder_main": {"text": message},
                "prompt_builder_main": {"template": this_chat_template, "query": message},
                "memory_joiner_main": {"value": [ChatMessage.from_user(message)]},
                "memory_retriever_main": {"last_k": 5}
            }
        )

        # Only the rendered user message (index 1) is sent to the text generator.
        # NOTE(review): the system message at index 0 is therefore never seen by the model.
        prompt_for_generator = res["prompt_builder_main"]["prompt"][1].content
        gen_res = generator.run(prompt_for_generator)
        reply = gen_res["replies"][0]

        # Persist the assistant turn so later questions see it as conversation history.
        rag_pipe_memory_writer.run(
            data={
                "memory_joiner": {"value": [ChatMessage.from_assistant(reply)]}
            }
        )
        return reply

    demo = gr.ChatInterface(
        chatbot=gr.Chatbot(height=800, placeholder="<strong>LLM-assisted exploration of local XAI analysis. </strong><br>Ask Me Anything"),
        fn=this_chatbot,
        examples=[
            "Which of the counterfactual explanations fits the best in order to explain the model prediction?",
            "Explain the model prediction using only local counterfactual explanations.",
            "Explain the model prediction using only local SHAP explanations.",
        ],
        title="LLM powered XAI",
        description="Explore counterfactual explanation and SHAP explanations for individual test instances.",
        theme="soft",
    )

    demo.launch()



In [None]:
# Launch the chat demo for the test case at positional index 1.
run_demo(this_index= 1)

