<a href="https://colab.research.google.com/github/BianchiLuca28/XAI-Healthcare/blob/main/notebooks/llm_call.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and setup

In [26]:
!pip install python-dotenv groq

Collecting groq
  Downloading groq-0.15.0-py3-none-any.whl.metadata (14 kB)
Downloading groq-0.15.0-py3-none-any.whl (109 kB)
Installing collected packages: groq
Successfully installed groq-0.15.0


In [27]:
from dotenv import load_dotenv
import os
import requests
from groq import Groq

# Setting up the LLM

In [35]:
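def llm(query, token, model, timeout=60):
    """
    Query an LLM using the Hugging Face Inference API.

    Parameters:
        query (str): The input prompt.
        token (str): Hugging Face API token.
        model (str): Hugging Face model ID, e.g. "mistralai/Mistral-Nemo-Instruct-2407".
        timeout (float): Seconds to wait for the API before giving up.

    Returns:
        str: Generated response from the LLM.
    """
    # Text-generation settings: a low temperature for near-deterministic output,
    # and return only the completion rather than prompt + completion
    parameters = {
        "max_new_tokens": 1000,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.95,
        "return_full_text": False
    }

    headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json'
    }

    payload = {
        "inputs": query,
        "parameters": parameters
    }

    # A timeout keeps the notebook from hanging if the endpoint never answers
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model}",
        headers=headers,
        json=payload,
        timeout=timeout,
    )

    if response.status_code != 200:
        raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")

    # The text-generation endpoint returns a list with one dict per generated sequence
    return response.json()[0].get('generated_text', '').strip()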
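`llm()` indexes `response.json()[0]` directly, so a body that is not the usual list of generations fails with a bare `KeyError` or `TypeError` instead of a readable message. A minimal sketch of a more defensive parsing step follows. The helper name `parse_generated_text` is hypothetical, and treating a dict with an `"error"` key as a failure is an assumption about the API's error format.

```python
import json


def parse_generated_text(body):
    """Extract the generated text from a decoded Inference API JSON body.

    Assumes the text-generation shape used by llm(): a list of dicts with a
    "generated_text" key. A dict carrying an "error" key is treated as a
    failure (an assumption about the error format, not a documented contract).
    """
    if isinstance(body, dict) and "error" in body:
        raise RuntimeError(f"Inference API error: {body['error']}")
    if not isinstance(body, list) or not body:
        raise ValueError(f"Unexpected response shape: {type(body).__name__}")
    return body[0].get("generated_text", "").strip()


# Offline example with a hand-written response body (no network call)
sample = json.loads('[{"generated_text": "  - AGE: patient age in years  "}]')
print(parse_generated_text(sample))  # prints: - AGE: patient age in years
```

Inside `llm()`, the last line would then become `return parse_generated_text(response.json())`.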

In [29]:
load_dotenv()

False

In [30]:
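# Read the Hugging Face token from the environment (variable name assumed to be HUGGING_FACE)
token = os.getenv("HUGGING_FACE")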
model = "mistralai/Mistral-Nemo-Instruct-2407"
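`load_dotenv()` returned `False` above, so no variables were loaded from a `.env` file and the token has to come from somewhere else. One way to fail early with a readable message, rather than sending an empty `Bearer` header, is a small helper like the sketch below. The name `get_hf_token` is hypothetical, and the default variable name `HUGGING_FACE` is an assumption taken from the cell above.

```python
import os


def get_hf_token(var_name="HUGGING_FACE"):
    """Return the Hugging Face token stored in an environment variable.

    The default variable name is an assumption; change it to match the
    name used in your .env file or shell environment.
    """
    token = os.getenv(var_name)
    if not token:
        # Fail before any HTTP request is made
        raise RuntimeError(
            f"Environment variable {var_name} is not set; "
            "add it to a .env file or export it before calling llm()."
        )
    return token
```

Usage: `token = get_hf_token()`.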

# Prompting

In [36]:
def format_shap_prompt(customer_data, prediction, prediction_proba):
    """Build an LLM prompt explaining one customer's prediction from SHAP values.

    customer_data must be a DataFrame with "Feature", "Feature Value" and "SHAP Value" columns.
    """
    risk_label = "high risk" if prediction == 1 else "low risk"
    features_and_shap_values = ""

    # Iterate through each feature in the SHAP data
    for _, row in customer_data.iterrows():
        feature = row["Feature"]
        feature_value = row["Feature Value"]
        shap_value = row["SHAP Value"]

        # Add the feature and its SHAP impact: numbers get two decimals, while
        # booleans (e.g. one-hot flags, which Python treats as ints) and strings are printed as-is
        if isinstance(feature_value, (int, float)) and not isinstance(feature_value, bool):
            features_and_shap_values += f"- {feature}: {feature_value:.2f} (SHAP impact: {shap_value:.2f})\n"
        else:
            features_and_shap_values += f"- {feature}: {feature_value} (SHAP impact: {shap_value:.2f})\n"
    prompt = f"""
You are a smart and helpful explainer and interpreter for a machine learning model that classifies customers as high or low risk regarding credit default.

The customer in question has been classified as {risk_label} with a probability of {prediction_proba:.2f}.
Below are the customer's features and their corresponding SHAP values:

{features_and_shap_values}

Instructions:
1. Provide only the main reasons the customer was classified as {risk_label} by referencing the most impactful features and the reasons behind each of them.
2. Do not include any disclaimers, contact information, or explanations of what SHAP values are. You should analyze the SHAP values in relation to the feature values and the connections between them, but the values themselves don't have to be mentioned in the reply.
3. Write your answer without extra salutations, sign-offs and mentions of the SHAP values.
4. Write your answer using bullet points for the features you want to mention.

Take your time to analyze the values and the connections between them thoroughly. Note that the false flags for the status features come from one-hot encoding: only the true flag describes the analyzed customer.
Please explain the primary factors that led to this classification:
"""

    return prompt
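Instruction 1 in the prompt asks the model to focus on the most impactful features, but `format_shap_prompt()` sends every row. Pre-selecting the rows with the largest absolute SHAP values keeps the prompt shorter and points the model at those same features. The sketch below uses plain `(feature, value, shap)` tuples and made-up feature names as stand-ins for the `customer_data` DataFrame; both are assumptions for illustration, and `top_shap_features` is a hypothetical helper.

```python
def top_shap_features(rows, k=5):
    """Return the k (feature, value, shap) rows with the largest |SHAP|.

    rows: iterable of (feature, feature_value, shap_value) tuples, a
    plain-Python stand-in for the DataFrame format_shap_prompt() expects.
    Sign is ignored, so strong pushes toward either class are both kept.
    """
    return sorted(rows, key=lambda r: abs(r[2]), reverse=True)[:k]


# Hypothetical rows; the feature names are invented for this example
rows = [
    ("income", 20000.0, -0.12),
    ("late_payments", 3, 0.45),
    ("age", 24, 0.03),
    ("status_employed", True, 0.30),
]
for feature, value, shap in top_shap_features(rows, k=2):
    print(f"- {feature}: {value} (SHAP impact: {shap:.2f})")
```

On the DataFrame itself, the same selection amounts to ordering rows by `customer_data["SHAP Value"].abs()` in descending order and keeping the first `k` before calling `format_shap_prompt()`.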

In [37]:
prompt = f"""
You are a smart and helpful explainer and interpreter for a machine learning model that predicts myocardial infarction complications.

Given these features of the dataset, explain each of them briefly:
ID,AGE,SEX,INF_ANAM,STENOK_AN,FK_STENOK,IBS_POST,IBS_NASL,GB,SIM_GIPERT,DLIT_AG,ZSN_A,nr_11,nr_01,nr_02,nr_03,nr_04,nr_07,nr_08,np_01,np_04,np_05,np_07,np_08,np_09,np_10,endocr_01,endocr_02,endocr_03,zab_leg_01,zab_leg_02,zab_leg_03,zab_leg_04,zab_leg_06,S_AD_KBRIG,D_AD_KBRIG,S_AD_ORIT,D_AD_ORIT,O_L_POST,K_SH_POST,MP_TP_POST,SVT_POST,GT_POST,FIB_G_POST,ant_im,lat_im,inf_im,post_im,IM_PG_P,ritm_ecg_p_01,ritm_ecg_p_02,ritm_ecg_p_04,ritm_ecg_p_06,ritm_ecg_p_07,ritm_ecg_p_08,n_r_ecg_p_01,n_r_ecg_p_02,n_r_ecg_p_03,n_r_ecg_p_04,n_r_ecg_p_05,n_r_ecg_p_06,n_r_ecg_p_08,n_r_ecg_p_09,n_r_ecg_p_10,n_p_ecg_p_01,n_p_ecg_p_03,n_p_ecg_p_04,n_p_ecg_p_05,n_p_ecg_p_06,n_p_ecg_p_07,n_p_ecg_p_08,n_p_ecg_p_09,n_p_ecg_p_10,n_p_ecg_p_11,n_p_ecg_p_12,fibr_ter_01,fibr_ter_02,fibr_ter_03,fibr_ter_05,fibr_ter_06,fibr_ter_07,fibr_ter_08,GIPO_K,K_BLOOD,GIPER_NA,NA_BLOOD,ALT_BLOOD,AST_BLOOD,KFK_BLOOD,L_BLOOD,ROE,TIME_B_S,R_AB_1_n,R_AB_2_n,R_AB_3_n,NA_KB,NOT_NA_KB,LID_KB,NITR_S,NA_R_1_n,NA_R_2_n,NA_R_3_n,NOT_NA_1_n,NOT_NA_2_n,NOT_NA_3_n,LID_S_n,B_BLOK_S_n,ANT_CA_S_n,GEPAR_S_n,ASP_S_n,TIKL_S_n,TRENT_S_n,FIBR_PREDS,PREDS_TAH,JELUD_TAH,FIBR_JELUD,A_V_BLOK,OTEK_LANC,RAZRIV,DRESSLER,ZSN,REC_IM,P_IM_STEN,LET_IS
"""
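The header above lists 124 columns, while `llm()` allows at most 1000 new tokens, roughly 8 tokens per feature. Both recorded answers in this notebook fall back to lumping whole groups together (for example `nr_01` to `nr_11`). A hedged workaround is to split the header into batches and send one shorter prompt per batch. The helpers `feature_batches` and `batch_prompts` are hypothetical, and a batch size of 25 is an arbitrary starting point.

```python
def feature_batches(header, batch_size=25):
    """Split a comma-separated header line into lists of feature names."""
    names = [name.strip() for name in header.split(",") if name.strip()]
    return [names[i:i + batch_size] for i in range(0, len(names), batch_size)]


def batch_prompts(header, batch_size=25):
    """Build one 'explain these features' prompt per batch of columns."""
    prompts = []
    for batch in feature_batches(header, batch_size):
        prompts.append(
            "You are a smart and helpful explainer and interpreter for a machine "
            "learning model that predicts myocardial infarction complications.\n\n"
            "Given these features of the dataset, explain each of them briefly:\n"
            + ",".join(batch)
        )
    return prompts


# Small example with the first six columns and batches of four
header = "ID,AGE,SEX,INF_ANAM,STENOK_AN,FK_STENOK"
for p in batch_prompts(header, batch_size=4):
    print(p.splitlines()[-1])  # the feature names covered by this prompt
```

Each prompt could then be sent with `llm(p, token, model)` and the answers concatenated.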

In [38]:
# formatted_prompt = format_shap_prompt(
#     customer_data,
#     prediction=1,  # customer classified as high risk
#     prediction_proba=0.85)

# print(formatted_prompt)

try:
    response = llm(prompt, token, model)
    print("Generated Response:\n", response)
except Exception as e:
    print(e)

Generated Response:
 Here are the explanations:

1. **ID**: Unique identifier for each patient.
2. **AGE**: Patient's age in years.
3. **SEX**: Patient's sex (0 - Female, 1 - Male).
4. **INF_ANAM**: History of previous myocardial infarction (0 - No, 1 - Yes).
5. **STENOK_AN**: History of coronary artery disease (0 - No, 1 - Yes).
6. **FK_STENOK**: Time since the last coronary artery disease diagnosis in years.
7. **IBS_POST**: Post-infarction angina (0 - No, 1 - Yes).
8. **IBS_NASL**: Nasal congestion (0 - No, 1 - Yes).
9. **GB**: Body mass index (BMI) category (0 - Underweight, 1 - Normal, 2 - Overweight, 3 - Obese).
10. **SIM_GIPERT**: History of hypertension (0 - No, 1 - Yes).
11. **DLIT_AG**: Age at the onset of diabetes mellitus in years.
12. **ZSN_A**: Smoking status (0 - Non-smoker, 1 - Smoker).
13. **nr_01** to **nr_11**, **np_01** to **np_10**, **endocr_01** to **endocr_03**, **zab_leg_01** to **zab_leg_06**: These are binary features representing various medical conditions, s
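The Hugging Face Inference API can answer with a non-200 status (commonly 503) while a model is still being loaded, and `llm()` then raises immediately. A small retry wrapper with exponential backoff is one way to smooth this over. `call_with_retries` is a hypothetical helper, the delays are arbitrary, and retrying on every exception is a simplification; a stricter version would retry only on 503 responses or timeouts.

```python
import time


def call_with_retries(fn, *args, attempts=3, base_delay=2.0, sleep=time.sleep, **kwargs):
    """Call fn(*args, **kwargs), retrying with exponential backoff on failure.

    Retries on any Exception, which covers the error llm() raises on a
    non-200 status. `sleep` is injectable so the logic can be tested offline.
    """
    for attempt in range(attempts):
        try:
            return fn(*args, **kwargs)
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay * 2 ** attempt)  # 2 s, then 4 s, ... with the defaults
```

Usage: `response = call_with_retries(llm, prompt, token, model)`.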

# Using Groq

In [34]:
client = Groq(
    # api_key=os.environ.get("GROQ_API_KEY")  # optional: the client reads GROQ_API_KEY from the environment by default
)

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": prompt,
        }
    ],
    model="llama-3.3-70b-versatile",
)

print(chat_completion.choices[0].message.content)

This dataset appears to be related to Myocardial Infarction (MI), commonly known as a heart attack, and its complications. The features can be broadly categorized into several groups: patient demographics, medical history, clinical measurements, lab tests, medication, and post-MI complications. Here's a brief explanation of each feature:

**Patient Demographics:**

1. ID: Unique patient identifier
2. AGE: Patient's age
3. SEX: Patient's sex (male or female)

**Medical History:**

1. INF_ANAM: Anamnesis of infarction (previous heart attack)
2. STENOK_AN: Anamnesis of stenocardia (chest pain)
3. FK_STENOK: Frequency of stenocardia attacks
4. IBS_POST: Post-infarction syndrome
5. IBS_NASL: Family history of cardiovascular disease
6. GB: History of hypertension
7. SIM_GIPERT: Severe hypertension

**Clinical Measurements:**

1. DLIT_AG: Duration of angina pectoris (chest pain)
2. ZSN_A: Presence of Zhdań arteriosclerosis (hardening of the arteries)
3. nr_11 to nr_08: Unspecified clinical me
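The Hugging Face call uses temperature 0.1 and a 1000-token cap, while the Groq cell relies on the API defaults, so the two answers are not produced under the same settings. A minimal sketch of a wrapper that passes matching settings through the same `client.chat.completions.create` call used above follows. The function name `groq_llm` is hypothetical, and treating `max_tokens` as the counterpart of `max_new_tokens` is an approximation.

```python
def groq_llm(query, client, model="llama-3.3-70b-versatile", temperature=0.1, max_tokens=1000):
    """Query a chat model through a Groq client with llm()-like settings.

    client: a groq.Groq instance, or any object exposing the same
    chat.completions.create(...) interface (handy for offline tests).
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": query}],
        model=model,
        temperature=temperature,  # same low temperature as the Hugging Face call
        max_tokens=max_tokens,    # rough counterpart of max_new_tokens
    )
    return chat_completion.choices[0].message.content.strip()
```

Usage: `print(groq_llm(prompt, client))`.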