In [None]:
!pip install pandas openpyxl setfit sentence-transformers datasets joblib torch groq

In [None]:
import pandas as pd
import joblib
from datasets import Dataset
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
import os
import numpy as np
from sklearn.model_selection import train_test_split
from groq import Groq # Import Groq client
import time # To add slight delay if needed for API rate limits

In [None]:
import os
from getpass import getpass

# Never paste the key into the notebook: it would be saved with the file and its outputs.
# Use the value already present in the environment, otherwise prompt for it without echoing.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("Enter your GROQ_API_KEY: ").strip()

In [None]:
# --- Configuration ---
# Excel workbook used to train the classifier.
# *** Please ensure this path is correct *** (it is a machine-specific absolute Windows path)
file_path = r"C:\Users\punee\Downloads\final_diseases_dataset.xlsx"
# Where the trained SetFit model is saved to and later loaded from.
model_save_path = "disease_predictor_setfit_model"
# joblib dump of the {integer class id -> disease name} mapping.
label_mapping_save_path = "disease_label_mapping.pkl"
# Note: There is no disease_info save path; treatment/prevention info is fetched live from Groq.

# Training parameters
PRETRAINED_MODEL_NAME = "sentence-transformers/paraphrase-mpnet-base-v2" # Base model for SetFit
NUM_ITERATIONS = 20  # passed to SetFitTrainer(num_iterations=...)
BATCH_SIZE = 16  # passed to SetFitTrainer(batch_size=...)
NUM_EPOCHS_HEAD = 1  # passed to SetFitTrainer(num_epochs=...)
TEST_SET_SIZE = 0.2  # fraction of the dataset held out for evaluation

# Groq Configuration
# Model used for the treatment/prevention lookup (a Llama 4 Maverick instruct model).
# Other model IDs previously suggested here: "llama3-8b-8192", "llama3-70b-8192", "gemma-7b-it"
# (check Groq's current model list before switching).
GROQ_MODEL_NAME = "meta-llama/llama-4-maverick-17b-128e-instruct"

In [None]:
# --- Functions ---

def load_and_prepare_data_for_setfit(file_path):
    """Load the disease workbook and build the text/label dataset used for SetFit training.

    The 'Overview/Definition', 'Symptoms' and 'Causes/Risk Factors' columns are
    concatenated (whitespace-normalised) into one text per row; 'Disease Name'
    is encoded as an integer class id.

    Args:
        file_path: Path to the Excel file.

    Returns:
        (hf_dataset, label_mapping_int_to_name): a `Dataset` with "text" and
        "label" columns, and a dict mapping each integer class id to its
        disease name.

    Raises:
        FileNotFoundError: if `file_path` does not exist.
        ValueError: if a required column is missing or no row has a disease name.
    """
    print(f"Loading dataset from: {file_path}")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Error: Dataset file not found at {file_path}")

    df = pd.read_excel(file_path)
    print("Dataset loaded successfully.")
    print("Columns found:", df.columns.tolist())

    # Define required columns for SetFit training
    required_text_cols = ['Overview/Definition', 'Symptoms', 'Causes/Risk Factors']
    required_label_col = 'Disease Name'
    all_required_cols = required_text_cols + [required_label_col]

    missing_cols = [col for col in all_required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Error: Missing required columns for SetFit training: {missing_cols}.")

    print("Required columns for training found.")

    # A row without a disease name cannot be used for supervised training; casting it
    # to str would otherwise create a bogus class literally named "nan".
    unlabeled_rows = df[required_label_col].isna()
    if unlabeled_rows.any():
        print(f"Warning: Dropping {int(unlabeled_rows.sum())} row(s) with no '{required_label_col}'.")
        df = df.loc[~unlabeled_rows]
    if df.empty:
        raise ValueError("Error: Dataset contains no rows with a disease name.")

    # Combine text features. Cells are cast to str first because Excel columns can
    # contain numbers/dates, and ' '.join raises TypeError on non-string values.
    print("Combining text features...")
    text_cells = df[required_text_cols].fillna('').astype(str)
    combined_text = (
        text_cells.agg(' '.join, axis=1)
        .str.replace(r'\s+', ' ', regex=True)
        .str.strip()
    )

    # Prepare labels: category codes give a stable name -> integer id encoding.
    label_categories = df[required_label_col].astype(str).astype("category")
    label_codes = label_categories.cat.codes
    label_mapping_int_to_name = dict(enumerate(label_categories.cat.categories))

    print(f"Found {len(label_mapping_int_to_name)} unique diseases for classification.")

    # Create Hugging Face Dataset (only text and label needed for training)
    hf_dataset = Dataset.from_dict({
        "text": combined_text.tolist(),
        "label": label_codes.tolist()
    })

    return hf_dataset, label_mapping_int_to_name

def train_disease_classifier(train_dataset, eval_dataset, label_mapping):
    """Fit a SetFit classifier on `train_dataset` and optionally evaluate it.

    Args:
        train_dataset: Dataset with "text" and "label" columns used for training.
        eval_dataset: Dataset used for evaluation; when falsy, evaluation is skipped.
        label_mapping: Mapping of class id to disease name (only its size is reported).

    Returns:
        (model, metrics): the trained model, and the result of `trainer.evaluate()`
        or None when no evaluation was run.
    """
    print(f"\n--- Starting SetFit Classifier Training using {PRETRAINED_MODEL_NAME} ---")
    print(f"Number of classes (diseases): {len(label_mapping)}")

    classifier = SetFitModel.from_pretrained(PRETRAINED_MODEL_NAME)

    # All trainer settings come from the configuration constants.
    trainer_settings = {
        "model": classifier,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "loss_class": CosineSimilarityLoss,
        "metric": "accuracy",
        "batch_size": BATCH_SIZE,
        "num_iterations": NUM_ITERATIONS,
        "num_epochs": NUM_EPOCHS_HEAD,
        "column_mapping": {"text": "text", "label": "label"},
    }
    setfit_trainer = SetFitTrainer(**trainer_settings)

    print("Training SetFit classifier...")
    setfit_trainer.train()

    if not eval_dataset:
        metrics = None
        print("No evaluation dataset provided, skipping evaluation.")
    else:
        print("Evaluating classifier performance...")
        metrics = setfit_trainer.evaluate()
        print(f"Evaluation Metrics: {metrics}")

    print("Classifier training complete.")
    return classifier, metrics

# Headings (lower-case) the model may use to introduce each section; tried in order.
_TREATMENT_KEYWORDS = ("treatment options:", "general treatment:", "treatment:", "1. general treatment options:")
_PREVENTION_KEYWORDS = ("prevention measures:", "general prevention:", "prevention:", "2. general prevention measures:")
# Characters that may decorate a heading in front of the keyword:
# list numbering ("2.", "2)"), Markdown emphasis/heading marks and bullets.
_HEADING_DECORATION = " \t*#0123456789.)-"


def _find_heading(text_lower, keywords):
    """Return (heading_start, content_start) for the first listed keyword found, else None.

    `heading_start` is moved back over decoration that precedes the keyword on the
    same line (e.g. "**2. General "), so that a preceding section can be cut
    cleanly before the whole heading instead of ending in a heading fragment.
    """
    for keyword in keywords:
        keyword_pos = text_lower.find(keyword)
        if keyword_pos == -1:
            continue
        line_start = text_lower.rfind("\n", 0, keyword_pos) + 1
        prefix = text_lower[line_start:keyword_pos]
        if not prefix.replace("general", "").strip(_HEADING_DECORATION):
            # Everything before the keyword on this line is heading decoration.
            heading_start = line_start
        else:
            # Heading is inline with other text: only drop emphasis marks glued to it.
            heading_start = line_start + len(prefix.rstrip("*# \t"))
        return heading_start, keyword_pos + len(keyword)
    return None


def _clean_section(raw_text):
    """Strip whitespace and the closing emphasis marks of a heading such as '**Treatment:**'."""
    return raw_text.lstrip("*").strip()


def _split_treatment_prevention(response_content):
    """Split an LLM answer into (treatment_info, prevention_info) using heading keywords.

    This is a best-effort heuristic: if neither heading is recognised, the whole
    response is returned as the treatment text.
    """
    response_lower = response_content.lower()
    treatment_span = _find_heading(response_lower, _TREATMENT_KEYWORDS)
    prevention_span = _find_heading(response_lower, _PREVENTION_KEYWORDS)

    if treatment_span is None and prevention_span is None:
        print("Warning: Could not reliably split Groq response into Treatment/Prevention sections.")
        return response_content, "(See above response)"

    treatment_info = "Could not parse Treatment info from response."
    prevention_info = "Could not parse Prevention info from response."

    if treatment_span is not None and prevention_span is not None:
        treatment_heading, treatment_body = treatment_span
        prevention_heading, prevention_body = prevention_span
        if treatment_body < prevention_body:
            # Treatment first: it ends where the prevention heading begins.
            treatment_info = _clean_section(response_content[treatment_body:prevention_heading])
            prevention_info = _clean_section(response_content[prevention_body:])
        else:
            # Prevention first (less common given the prompt).
            prevention_info = _clean_section(response_content[prevention_body:treatment_heading])
            treatment_info = _clean_section(response_content[treatment_body:])
    elif treatment_span is not None:
        treatment_info = _clean_section(response_content[treatment_span[1]:])
    else:
        prevention_info = _clean_section(response_content[prevention_span[1]:])

    return treatment_info, prevention_info


def get_info_from_groq(disease_name, symptoms_context=""):
    """Ask the Groq LLM for general treatment and prevention information.

    Args:
        disease_name: Name of the disease to ask about.
        symptoms_context: Optional user-provided symptom text added to the prompt.

    Returns:
        (treatment_info, prevention_info) as strings. This function does not
        raise: a missing API key, an API failure or an empty response is
        reported through the returned strings (and printed).
    """
    print(f"\n--- Querying Groq for info on: {disease_name} ---")

    # Checked outside the try block so that unrelated ValueErrors raised while
    # querying are not misreported as a missing API key.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("Configuration Error: GROQ_API_KEY environment variable not set.")
        return "Error: API key not configured.", "Please set the GROQ_API_KEY environment variable."

    try:
        client = Groq(api_key=api_key)

        # Construct a prompt for the LLM
        prompt_parts = [
            "You are a helpful assistant providing general medical information.",
            f"Based on the disease identified as '{disease_name}', please provide:",
            "1. General Treatment Options: (Common approaches, medications, therapies - mention this is not specific advice)",
            "2. General Prevention Measures: (Lifestyle advice, precautions)",
            f"\nContext (Symptoms provided by user, if any): '{symptoms_context}'" if symptoms_context else "",
            "\nIMPORTANT: Frame the response clearly separating Treatment and Prevention. State explicitly that this information is general and not a substitute for professional medical advice.",
            "Keep the response concise and informative.",
        ]
        prompt = "\n".join(prompt_parts)

        print(f"Sending request to Groq model: {GROQ_MODEL_NAME}...")
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
        )

        response_content = chat_completion.choices[0].message.content
        print("Groq response received.")

        if not response_content or not response_content.strip():
            print("Warning: Groq returned an empty response.")
            return "Error querying Groq: empty response.", ""

        return _split_treatment_prevention(response_content)

    except Exception as e:
        print(f"An error occurred while querying Groq: {e}")
        import traceback
        traceback.print_exc()
        return f"Error querying Groq: {e}", ""


def predict_disease_and_get_info(symptom_description):
    """Classify a symptom description and fetch general info for the predicted disease.

    The saved SetFit model and label mapping are loaded from `model_save_path`
    and `label_mapping_save_path` on every call.

    Args:
        symptom_description: Free-text description entered by the user.

    Returns:
        (disease_name, treatment_info, prevention_info). On failure the first
        element is "Error" and the remaining elements describe the problem.
    """
    print("\n--- Initiating Prediction ---")

    # Both artifacts must be present before anything is loaded.
    artifacts_present = os.path.exists(model_save_path) and os.path.exists(label_mapping_save_path)
    if not artifacts_present:
        print("Error: SetFit model or label mapping file not found.")
        print(f"Please ensure '{model_save_path}' and '{label_mapping_save_path}' exist.")
        print("You may need to run the script once to train and save the classifier.")
        missing_msg = "Classifier model files not found."
        return "Error", missing_msg, missing_msg

    try:
        # Step 1: load the classifier and the id -> name mapping.
        print(f"Loading SetFit classifier from: {model_save_path}")
        classifier = SetFitModel.from_pretrained(model_save_path)

        print(f"Loading label mapping from: {label_mapping_save_path}")
        id_to_name = joblib.load(label_mapping_save_path)

        # Step 2: classify; the prediction may be a tensor/array scalar or a plain number.
        print(f"Classifying disease based on input: '{symptom_description[:100]}...'")
        raw_prediction = classifier.predict([symptom_description])[0]
        if hasattr(raw_prediction, 'item'):
            raw_prediction = raw_prediction.item()
        class_id = int(raw_prediction)
        predicted_disease_name = id_to_name.get(class_id, "Unknown Disease")
        print(f"SetFit Prediction: {predicted_disease_name}")

        # Step 3: only query the LLM when the class id maps to a known disease.
        if predicted_disease_name == "Unknown Disease":
            unavailable = "Cannot fetch info for 'Unknown Disease'."
            return predicted_disease_name, unavailable, unavailable

        treatment_info, prevention_info = get_info_from_groq(predicted_disease_name, symptom_description)
        return predicted_disease_name, treatment_info, prevention_info

    except Exception as e:
        print(f"An error occurred during prediction pipeline: {e}")
        import traceback
        traceback.print_exc()
        return "Error", f"Prediction pipeline failed: {e}", ""


In [None]:
# --- Main Execution Logic ---
# Trains and saves the classifier when no saved artifacts exist, then runs an
# interactive loop that predicts a disease and fetches general info via Groq.
if __name__ == "__main__":
    # --- Optional Training Phase ---
    # Training is skipped automatically when both saved artifacts already exist;
    # delete them (or set run_training = True below the check) to force retraining.
    run_training = True
    if os.path.exists(model_save_path) and os.path.exists(label_mapping_save_path):
        print("Found existing SetFit model and mapping. Skipping training.")
        print(f"Model: {model_save_path}")
        print(f"Mapping: {label_mapping_save_path}")
        run_training = False

    if run_training:
        try:
            # 1. Load and Prepare Data for SetFit
            full_dataset, label_mapping = load_and_prepare_data_for_setfit(file_path)

            # 2. Split data
            print(f"Splitting data into Train ({1-TEST_SET_SIZE:.0%}) / Eval ({TEST_SET_SIZE:.0%})...")
            labels_for_split = np.asarray(full_dataset['label'])
            unique_labels, counts = np.unique(labels_for_split, return_counts=True)
            min_samples_per_class = counts.min() if len(counts) > 0 else 0

            eval_dataset = None
            if TEST_SET_SIZE > 0 and len(full_dataset) > 1:
                if min_samples_per_class >= 2:
                    # Stratify with scikit-learn on row indices. Dataset.train_test_split
                    # (stratify_by_column=...) only accepts ClassLabel columns and raises
                    # ValueError for the plain integer "label" column built here.
                    try:
                        train_indices, eval_indices = train_test_split(
                            np.arange(len(full_dataset)),
                            test_size=TEST_SET_SIZE,
                            random_state=42,
                            stratify=labels_for_split,
                        )
                        train_dataset = full_dataset.select(train_indices.tolist())
                        eval_dataset = full_dataset.select(eval_indices.tolist())
                    except ValueError as split_error:
                        # e.g. the evaluation split is smaller than the number of classes
                        print(f"Warning: Stratified split not possible ({split_error}).")
                        eval_dataset = None
                else:
                    print(f"Warning: Some classes have only {min_samples_per_class} sample. Cannot stratify split reliably.")

                if eval_dataset is None:
                    print("Using non-stratified split.")
                    split_datasets = full_dataset.train_test_split(test_size=TEST_SET_SIZE, seed=42) # No stratify
                    train_dataset = split_datasets['train']
                    eval_dataset = split_datasets['test']

                print(f"Training set size: {len(train_dataset)}")
                print(f"Evaluation set size: {len(eval_dataset)}")
            else:
                print("Using full dataset for training. No evaluation split created.")
                train_dataset = full_dataset
                # eval_dataset remains None

            # 3. Train the SetFit Classifier
            if eval_dataset:
                eval_ds_for_train = eval_dataset
            else:
                # Kept from the original flow: evaluate on the training data, but say so,
                # because such metrics overstate real-world accuracy.
                print("Warning: No evaluation split available; metrics will be computed on the training data.")
                eval_ds_for_train = train_dataset
            trained_model, training_metrics = train_disease_classifier(train_dataset, eval_ds_for_train, label_mapping)

            # 4. Save the results (Model and Label Mapping only)
            print("\n--- Saving Training Artifacts ---")
            print(f"Saving SetFit model to '{model_save_path}'...")
            trained_model.save_pretrained(model_save_path)
            print(f"Saving label mapping to '{label_mapping_save_path}'...")
            joblib.dump(label_mapping, label_mapping_save_path)
            print("SetFit classifier model and mapping saved successfully.")

        except (FileNotFoundError, ValueError, ImportError) as e:
            print(f"\n--- Setup/Training Error ---")
            print(e)
            print("Please check file path, Excel columns, required libraries, and GROQ_API_KEY environment variable.")
            # Stop here with a failure status; exit() is an interactive convenience
            # and reports success (status 0) when run as a script.
            raise SystemExit(1)
        except Exception as e:
            print(f"\n--- An Unexpected Error Occurred During Setup/Training ---")
            print(e)
            import traceback
            traceback.print_exc()
            raise SystemExit(1)  # Stop if training fails


    # --- Interactive Prediction Phase ---
    print("\n--- Interactive Disease Prediction & Info Retrieval ---")
    # Check for Groq API key availability before starting interactive part
    if not os.environ.get("GROQ_API_KEY"):
         print("\n*** WARNING: GROQ_API_KEY environment variable not detected. ***")
         print("   Treatment/Prevention info retrieval via Groq will fail.")
         print("   Please set the environment variable and restart the script.")

    print("\nEnter symptoms or description. Type 'quit' to exit.")

    while True:
        user_input = input("\nSymptoms: ")
        # Compare on the stripped text so that ' quit ' also exits.
        if user_input.strip().lower() == 'quit':
            break
        if not user_input.strip():
            print("Please enter some text.")
            continue

        predicted_disease, treatment, prevention = predict_disease_and_get_info(user_input)

        print("\n--- Prediction & Information ---")
        print(f"Predicted Disease (using SetFit + local data): {predicted_disease}")
        print("\n[!] Disclaimer: The predicted disease is based on the fine-tuned model. Both the prediction and the generated information below are NOT substitutes for professional medical advice. Always consult a qualified healthcare professional.")
        print(f"\nTreatment Info (Generated by AI via Groq - General Info):\n{treatment}")
        print(f"\nPrevention Info (Generated by AI via Groq - General Info):\n{prevention}")
        print("--------------------------------")

In [None]:
import os
from groq import Groq

# Groq model and API key for the chat assistant (the key is read from the environment).
GROQ_MODEL_NAME = "meta-llama/llama-4-maverick-17b-128e-instruct"
API_KEY = os.getenv("GROQ_API_KEY")

# Create Groq client
client = Groq(api_key=API_KEY)

# System prompt that defines the assistant's role for the whole conversation.
system_message = {
    "role": "system",
    "content": (
        "You are an intelligent, caring AI medical assistant named MedBot. "
        "You are integrated into a diagnostic system that uses symptom inputs and medical models to help users understand their health. "
        "Be conversational, empathetic, and helpful. Answer symptom-related questions, guide users through general medical concerns, and explain what-if scenarios. "
        "Also assist in interpreting results from an AI-powered system with disease predictions, wearable data, or scenario simulations. "
        "Always remind users that you are not a substitute for professional medical advice."
    )
}

# Chat history; the full list is sent with every request so the model keeps context.
history = [system_message]

print("\n🩺 Welcome to MedBot – Your AI Medical Chat Assistant (Groq Powered)")
print("Type 'exit' to quit. Ask about symptoms, health concerns, or anything related to diagnosis, treatment, or wellness.\n")

# Chat loop
while True:
    user_input = input("You: ").strip()
    if user_input.lower() in ["exit", "quit"]:
        print("MedBot: Stay healthy! Remember to consult a doctor for medical decisions.")
        break
    if not user_input:
        # Nothing to send; do not record or transmit an empty message.
        continue

    # Add user input to chat history
    history.append({"role": "user", "content": user_input})

    # Send conversation to Groq
    try:
        response = client.chat.completions.create(
            model=GROQ_MODEL_NAME,
            messages=history
        )
        bot_reply = response.choices[0].message.content
    except Exception as e:
        # Remove the unanswered user turn so a failed request does not leave the
        # history with consecutive user messages on the next attempt.
        history.pop()
        print(f"\nError: {e}\n")
    else:
        history.append({"role": "assistant", "content": bot_reply})
        print(f"\nMedBot: {bot_reply}\n")
