In [None]:
import os
from getpass import getpass

# Credentials must not be hardcoded in the notebook. Keys that are already
# exported in the environment are reused; only missing (or empty) ones are
# requested interactively, without echoing them into the cell output.
# Assigning '' unconditionally would overwrite a key configured outside the
# notebook and leave both API clients without a usable credential.
for _key_name in ("GROQ_API_KEY", "TAVILY_API_KEY"):
    if not os.environ.get(_key_name):
        os.environ[_key_name] = getpass(f"Enter {_key_name}: ").strip()

print("API keys have been set in the environment.")

API keys have been set in the environment.


In [None]:
pip install transformers datasets torch scikit-learn faiss-cpu sentence-transformers groq tavily-python pandas



In [None]:
import os
import pandas as pd
from groq import Groq

# Fail fast with a real exception instead of calling exit(): exit() is meant
# for interactive shells and does not give "Run All" a clear error to stop on,
# so later code in this cell would run without a usable `client`.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    raise RuntimeError("GROQ_API_KEY is not set. Set it before initializing the Groq client.")

try:
    client = Groq(api_key=_groq_api_key)
except Exception as e:
    print(f"Error initializing Groq client: {e}. Ensure GROQ_API_KEY is set.")
    raise
else:
    print("Groq client initialized successfully.")

def generate_synthetic_data(category, num_samples=10, llm_client=None,
                            model="llama3-70b-8192", batch_size=20):
    """Generate synthetic medical claims for one category with a chat LLM.

    Claims are requested in batches of at most ``batch_size`` so each response
    fits inside the 1024-token completion limit; a single request for ~100
    claims gets cut off and returns fewer claims than asked for.

    Args:
        category: Label to generate claims for ('Myth', 'Not a Myth' or
            'Inconclusive', as described in the prompt).
        num_samples: Number of distinct claims wanted.
        llm_client: Object exposing ``chat.completions.create(...)``. Defaults
            to the module-level ``client``.
        model: Model identifier passed to the completion call.
        batch_size: Maximum number of claims requested per API call.

    Returns:
        A list of at most ``num_samples`` distinct claim strings. If an API
        call fails, the claims collected so far are returned (an empty list
        when the first call fails).
    """
    import re  # local import: the cell's import block does not provide `re`

    llm_client = client if llm_client is None else llm_client
    batch_size = max(1, int(batch_size))
    # Accept only list items such as "1. claim" / "2) claim". Splitting on any
    # '. ' also turned preamble sentences into bogus claims.
    numbered_line = re.compile(r"^\s*\d+\s*[.)]\s+(\S.*)$")

    print(f"Generating {num_samples} samples for category: {category}...")

    claims = []
    seen = set()
    # Bounded retries: duplicates or short responses must not loop forever.
    max_calls = 2 * (-(-max(num_samples, 0) // batch_size)) + 1
    for _attempt in range(max_calls):
        remaining = num_samples - len(claims)
        if remaining <= 0:
            break
        request_size = min(batch_size, remaining)
        prompt = f"""
    You are an expert in medical misinformation. Your task is to generate {request_size} distinct medical claims that fall under the category: '{category}'.

    - For 'Myth', create plausible-sounding but false medical statements.
    - For 'Not a Myth', create true, evidence-based medical statements.
    - For 'Inconclusive', create claims that are currently debated, lack sufficient evidence, or depend heavily on individual context.

    Please provide the output as a numbered list of claims. Do not add any other text, just the list.

    Category: {category}
    """
        try:
            chat_completion = llm_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=model,
                temperature=0.9,
                max_tokens=1024,
                n=1,
            )
            generated_text = chat_completion.choices[0].message.content or ""
        except Exception as e:
            print(f"An error occurred while calling the Groq API: {e}")
            return claims

        for line in generated_text.splitlines():
            match = numbered_line.match(line)
            if match is None:
                continue
            claim = match.group(1).strip()
            dedupe_key = claim.casefold()
            if dedupe_key in seen:
                continue  # batches are independent, so repeats are possible
            seen.add(dedupe_key)
            claims.append(claim)
            if len(claims) >= num_samples:
                break

    print(f"Successfully generated {len(claims)} claims for {category}.")
    return claims


# Number of claims to request for each label.
categories = {"Myth": 100, "Not a Myth": 100, "Inconclusive":80}

# One record per generated claim, tagged with the category it was requested for.
all_data = [
    {"text": claim_text, "label": label_name}
    for label_name, requested in categories.items()
    for claim_text in generate_synthetic_data(label_name, requested)
]

# Persist the generated claims so the training cell can reload them from disk.
df = pd.DataFrame(all_data)
df.to_csv("augmented_medical_claims.csv", index=False)

print("\nData generation complete. Saved to augmented_medical_claims.csv")
print(f"Total samples generated: {len(df)}")

Groq client initialized successfully.
Generating 100 samples for category: Myth...
Successfully generated 79 claims for Myth.
Generating 100 samples for category: Not a Myth...
Successfully generated 60 claims for Not a Myth.
Generating 80 samples for category: Inconclusive...
Successfully generated 61 claims for Inconclusive.

Data generation complete. Saved to augmented_medical_claims.csv
Total samples generated: 200


In [None]:
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support



# Reload the generated claims written by the data-generation cell.
df = pd.read_csv("augmented_medical_claims.csv")

# Map each distinct label string to an integer id (in order of first appearance).
labels = df['label'].unique().tolist()
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
df['label'] = df['label'].map(label2id)

# Stratified 80/20 split with a fixed seed.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['label'],
)

train_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, test_df)
)

# Pretrained checkpoint with a classification head sized to the label set.
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

def tokenize_function(examples):
    """Tokenize the "text" field of a batch, padded/truncated to 128 tokens."""
    encode_options = {"padding": "max_length", "truncation": True, "max_length": 128}
    return tokenizer(examples["text"], **encode_options)

# Tokenize both splits in batches, then drop the "__index_level_0__" column
# from each of them.
train_dataset, test_dataset = (
    split.map(tokenize_function, batched=True).remove_columns(["__index_level_0__"])
    for split in (train_dataset, test_dataset)
)


def compute_metrics(pred):
    """Return accuracy plus weighted precision, recall and F1 for a prediction batch.

    Reads ``pred.label_ids`` as the reference labels and takes the argmax over
    the last axis of ``pred.predictions`` as the predicted class ids.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # Tuple layout: (precision, recall, f1, support).
    weighted = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': weighted[2],
        'precision': weighted[0],
        'recall': weighted[1],
    }

NUM_TRAIN_EPOCHS = 6
TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 8

# Derive the warmup length from the actual schedule. The recorded run had 160
# training examples, i.e. 20 optimizer steps per epoch and 120 steps in total,
# so a fixed 500-step warmup never finished and the learning rate never
# reached its configured peak. Roughly 10% of the run is used instead.
# (Single-device estimate; with several devices the true step count is lower.)
steps_per_epoch = -(-len(train_dataset) // TRAIN_BATCH_SIZE)  # ceiling division
total_training_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
warmup_steps = max(1, total_training_steps // 10)

training_args = TrainingArguments(
    output_dir="./mythbert_classifier_output",
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    warmup_steps=warmup_steps,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    # Evaluate and checkpoint once per epoch so the best checkpoint can be
    # restored when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# NOTE(review): the held-out split doubles as the validation set used for
# checkpoint selection, so the reported metrics are optimistic.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

print("Starting model training...")
trainer.train()

# Save model and tokenizer side by side so the inference cell can load both
# from the same directory.
trainer.save_model("./fine-tuned-mythbert")
tokenizer.save_pretrained("./fine-tuned-mythbert")

print("\nTraining complete. Model saved to './fine-tuned-mythbert'")

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/225k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/160 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Starting model training...




<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkousjikshaw1111[0m ([33mkousjikshaw1111-kiit-deemed-to-be-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,1.1522,1.08117,0.275,0.219286,0.208333,0.275
2,0.9772,0.883467,0.8,0.792823,0.804848,0.8
3,0.7305,0.623482,0.9,0.901097,0.903846,0.9
4,0.4296,0.363617,0.925,0.926097,0.928846,0.925
5,0.1952,0.211424,0.925,0.926097,0.928846,0.925
6,0.1038,0.114626,0.95,0.950976,0.957143,0.95



Training complete. Model saved to './fine-tuned-mythbert'


In [None]:
pip install PyPDF2==3.0.1



In [None]:
import os
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import PyPDF2
import pandas as pd


# Directory holding the local knowledge-base documents. On first run it is
# created and seeded with small sample files so the indexing code below has
# something to work with.
knowledge_base_path = "knowledge_base"
if not os.path.exists(knowledge_base_path):
    os.makedirs(knowledge_base_path, exist_ok=True)
    # Files are written as UTF-8 because the loader reads .txt files as UTF-8.
    with open(os.path.join(knowledge_base_path, "who_facts.txt"), "w", encoding="utf-8") as f:
        f.write("The World Health Organization states that vaccines are safe and effective. They prevent many infectious diseases.")
    with open(os.path.join(knowledge_base_path, "cdc_facts.txt"), "w", encoding="utf-8") as f:
        f.write("The CDC confirms that there is no link between vaccines and autism. Childhood immunizations are crucial for public health.")

    print(f"Please place a dummy PDF file (e.g., 'sample.pdf') in the '{knowledge_base_path}' directory to test PDF loading.")

    # Build the CSV from explicit lines: an indented triple-quoted literal
    # carries its source indentation into the first cell of every data row.
    dummy_csv_content = "\n".join([
        "header1,header2,header3",
        "Vaccine safety,Rigorous testing,Approved by FDA",
        "Autism link,Debunked,No scientific evidence",
        "Immunization,Community health,Disease prevention",
    ]) + "\n"
    with open(os.path.join(knowledge_base_path, "medical_data.csv"), "w", encoding="utf-8") as f:
        f.write(dummy_csv_content)
    print(f"Created dummy 'medical_data.csv' in '{knowledge_base_path}'.")


def load_documents_from_directory(directory_path):
    """Load text segments from the .txt, .pdf and .csv files in a directory.

    - .txt files contribute their whole content as one document.
    - .pdf files contribute the text of all pages, joined by newlines.
    - .csv files contribute one document per row, built as
      "column: value" pairs joined by spaces; empty (NaN) cells are skipped.

    Files are visited in sorted name order so document positions (and
    therefore the ids assigned when indexing) are reproducible. Extensions are
    matched case-insensitively. Files that fail to load are reported and
    skipped, and content that is empty after stripping is never returned, so
    the two result lists always stay aligned with what can be indexed.

    Args:
        directory_path: Directory to scan (not recursive).

    Returns:
        (documents, filepaths): parallel lists of document text and a source
        label (the file path, or "<path> (row N)" for CSV rows).
    """
    documents = []
    filepaths = []
    for filename in sorted(os.listdir(directory_path)):
        filepath = os.path.join(directory_path, filename)
        extension = os.path.splitext(filename)[1].lower()
        content = ""

        if extension == ".txt":
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
            except Exception as e:
                print(f"Error reading TXT file {filename}: {e}")

        elif extension == ".pdf":
            try:
                with open(filepath, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    # Guard against pages that yield no text at all.
                    content = "\n".join((page.extract_text() or "") for page in reader.pages)
            except Exception as e:
                print(f"Error reading PDF file {filename}: {e}")
                content = ""

        elif extension == ".csv":
            try:
                table = pd.read_csv(filepath)
                for index, row in table.iterrows():
                    # Prefix each cell with its header for context; drop
                    # missing cells instead of indexing the text "nan".
                    row_parts = [
                        f"{col_name}: {cell_value}"
                        for col_name, cell_value in row.items()
                        if pd.notna(cell_value)
                    ]
                    if row_parts:
                        documents.append(" ".join(row_parts))
                        filepaths.append(f"{filepath} (row {index})")
            except Exception as e:
                print(f"Error reading CSV file {filename}: {e}")
            # CSV rows were appended individually above.
            continue

        # Strip before testing: whitespace-only files would otherwise add an
        # empty document that the indexer drops, shifting every later id.
        content = content.strip()
        if content:
            documents.append(content)
            filepaths.append(filepath)

    print(f"Loaded {len(documents)} textual segments/documents from '{directory_path}'.")
    return documents, filepaths

# Sentence-embedding model handed to create_faiss_index to encode the
# knowledge-base documents.
embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')

def create_faiss_index(docs, model):
    """Encode documents and store them in an id-mapped flat L2 FAISS index.

    Args:
        docs: Iterable of candidate documents. Entries that are not non-blank
            strings are ignored.
        model: Encoder exposing ``encode(list_of_str, convert_to_tensor=False)``
            and ``get_sentence_embedding_dimension()``.

    Returns:
        (index, valid_docs): the index and the list of documents actually
        encoded. Vector ids are positions in ``valid_docs`` (0..n-1), so
        callers that keep per-document metadata must align it with
        ``valid_docs`` rather than with ``docs``. With no valid documents an
        empty index of the same type and an empty list are returned.
    """
    print("Encoding documents for FAISS index...")
    valid_docs = [doc for doc in docs if isinstance(doc, str) and doc.strip()]
    if not valid_docs:
        print("No valid documents to encode. FAISS index will be empty.")
        # Same index type as the populated case, so callers can treat both
        # results uniformly.
        empty_flat = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
        return faiss.IndexIDMap(empty_flat), []

    if len(valid_docs) != len(docs):
        print(f"Warning: skipped {len(docs) - len(valid_docs)} empty or non-string documents.")

    # FAISS expects contiguous float32 vectors and int64 ids; being explicit
    # avoids depending on platform default dtypes.
    embeddings = np.ascontiguousarray(
        model.encode(valid_docs, convert_to_tensor=False), dtype=np.float32
    )
    embedding_dim = embeddings.shape[1]

    flat_index = faiss.IndexFlatL2(embedding_dim)
    index = faiss.IndexIDMap(flat_index)

    ids = np.arange(len(valid_docs), dtype=np.int64)
    index.add_with_ids(embeddings, ids)

    print(f"FAISS index created successfully with {index.ntotal} vectors.")
    return index, valid_docs

local_documents, doc_filepaths = load_documents_from_directory(knowledge_base_path)

# create_faiss_index ignores blank / non-string documents and numbers the
# remaining ones 0..n-1. Apply the same filter to the source labels first so
# line i of the document map always describes vector i.
_indexable = [
    (doc, source)
    for doc, source in zip(local_documents, doc_filepaths)
    if isinstance(doc, str) and doc.strip()
]
local_documents = [doc for doc, source in _indexable]
doc_filepaths = [source for doc, source in _indexable]

if local_documents:
    faiss_index, indexed_documents = create_faiss_index(local_documents, embedding_model)
    if faiss_index.ntotal > 0:
        faiss.write_index(faiss_index, "medical_kb.index")

        # One "<id>: <source>" line per indexed vector.
        with open("doc_map.txt", "w", encoding='utf-8') as f:
            for i, source in enumerate(doc_filepaths[:len(indexed_documents)]):
                f.write(f"{i}: {source}\n")
        print("FAISS index saved to 'medical_kb.index' and mapping to 'doc_map.txt'.")
    else:
        print("FAISS index could not be created as there were no valid documents after encoding.")
else:
    print("No documents found or successfully loaded from the knowledge base directory. FAISS index not created.")


Loaded 5 textual segments/documents from 'knowledge_base'.
Encoding documents for FAISS index...
FAISS index created successfully with 5 vectors.
FAISS index saved to 'medical_kb.index' and mapping to 'doc_map.txt'.


In [None]:

import os
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
from tavily import TavilyClient
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


print("Initializing components...")

# Chat-completion client used by generate_explanation; key read from the environment.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Web-search client used by retrieve_from_web; key read from the environment.
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))

# Encoder used for query embeddings (retrieve_from_faiss) and for new web
# documents (add_web_results_to_kb).
embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')


import re

try:
    faiss_index = faiss.read_index("medical_kb.index")
    with open("doc_map.txt", "r", encoding='utf-8') as f:
        # Each line is "<id>: <source>". Keep only the source: the id is the
        # line position, and keeping the prefix would make it pile up
        # ("0: 0: path") every time the map is written back. Repeated prefixes
        # left by earlier runs are removed as well. Lines are not filtered, so
        # positions keep matching vector ids.
        doc_filepaths = [re.sub(r"^(?:\d+: )+", "", line.strip()) for line in f]
    print("Existing FAISS index and document map loaded.")
    if len(doc_filepaths) != faiss_index.ntotal:
        print(
            f"Warning: document map has {len(doc_filepaths)} entries but the index "
            f"holds {faiss_index.ntotal} vectors; retrieved sources may be wrong."
        )
except Exception as e:
    print(f"Could not load existing FAISS index or doc map: {e}")
    faiss_index = None
    doc_filepaths = []
    print("Starting with an empty FAISS index and document list.")



model_path = "./fine-tuned-mythbert"
print(f"Loading fine-tuned model from {model_path}...")
try:
    # Check for the expected artifacts up front so a missing piece is reported
    # with a specific message.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    weight_files = ("pytorch_model.bin", "model.safetensors")
    if not any(os.path.exists(os.path.join(model_path, name)) for name in weight_files):
        raise FileNotFoundError(f"Model weights not found in {model_path}")

    tokenizer_config = os.path.join(model_path, "tokenizer_config.json")
    if not os.path.exists(tokenizer_config):
        raise FileNotFoundError(f"Tokenizer config not found in {model_path}")

    classifier_tokenizer = AutoTokenizer.from_pretrained(model_path)
    classifier_model = AutoModelForSequenceClassification.from_pretrained(model_path)
except Exception as e:
    # Any failure leaves both handles as None; classify_claim checks for that.
    print(f"Error loading fine-tuned model: {e}")
    classifier_tokenizer = None
    classifier_model = None
else:
    print("Fine-tuned model and tokenizer loaded successfully.")



def classify_claim(claim):
    """Return the fine-tuned classifier's label for a medical claim.

    Returns "Unknown Classification" when the model or tokenizer is not
    loaded, and "Classification Error" when inference raises.
    """
    if any(part is None for part in (classifier_model, classifier_tokenizer)):
        print("Classifier not loaded. Cannot classify claim.")
        return "Unknown Classification"

    try:
        encoded = classifier_tokenizer(
            claim,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )

        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = classifier_model(**encoded).logits

        best_class_id = logits.argmax().item()
        return classifier_model.config.id2label[best_class_id]
    except Exception as e:
        print(f"Error during classification: {e}")
        return "Classification Error"


def retrieve_from_faiss(query, k=3):
    """Return the source labels of the top-k nearest knowledge-base entries.

    Args:
        query: Text to embed and search for.
        k: Maximum number of neighbours to request.

    Returns:
        A string of "Source: <label>" lines separated by "\n---\n", or
        "No information found in knowledge base." when the index is
        unavailable or no hit maps to a known document.

    NOTE(review): only source identifiers are available here, because the
    document text itself is not stored next to the index. The context built
    from this function therefore contains no document content; persisting the
    text alongside the map would be needed to change that.
    """
    import re  # local import: not provided by the cell's import block

    if faiss_index is None or faiss_index.ntotal == 0:
        print("FAISS index is empty or not loaded. Skipping FAISS retrieval.")
        return "No information found in knowledge base."

    query_embedding = embedding_model.encode([query], convert_to_tensor=True)
    query_embedding = query_embedding.cpu().numpy().astype('float32')

    distances, indices = faiss_index.search(query_embedding, k)
    sources = []
    for i in indices[0]:
        # FAISS pads missing neighbours with -1; also skip ids without a map entry.
        if 0 <= i < len(doc_filepaths):
            # Drop any "<id>: " prefix carried over from doc_map.txt.
            label = re.sub(r"^(?:\d+: )+", "", doc_filepaths[i])
            entry = f"Source: {label}"
            if entry not in sources:  # the same source may be indexed more than once
                sources.append(entry)

    if not sources:
        return "No information found in knowledge base."

    # Real newlines: an escaped "\\n" would put a literal backslash-n into the prompt.
    return "\n---\n".join(sources)


def retrieve_from_web(query, max_results=3):
    """Run a basic Tavily search and return its raw result list.

    Returns an empty list when the response has no results or the search fails.
    """
    try:
        response = tavily_client.search(
            query=query,
            search_depth="basic",
            max_results=max_results,
        )
        has_results = bool(response) and 'results' in response and bool(response['results'])
        return response['results'] if has_results else []
    except Exception as e:
        print(f"Web search failed: {e}")
        return []



def add_web_results_to_kb(web_results):
    """Embed new web search results and append them to the FAISS index and doc map.

    Args:
        web_results: List of result dicts; the 'url' and 'content' keys are read.

    Side effects:
        Updates the module-level ``faiss_index`` and ``doc_filepaths`` and
        rewrites "medical_kb.index" and "doc_map.txt". Results whose URL is
        already in the document map are skipped, so repeated runs with the
        same search do not keep adding the same vectors.
    """
    global faiss_index, doc_filepaths, embedding_model
    import re  # local import: not provided by the cell's import block

    if not web_results:
        print("No web results to add to knowledge base.")
        return

    # Entries loaded from doc_map.txt may carry "<id>: " prefixes; strip them
    # in place so they do not pile up when the map is written back and so the
    # duplicate check below compares plain source labels.
    id_prefix = re.compile(r"^(?:\d+: )+")
    doc_filepaths[:] = [id_prefix.sub("", entry) for entry in doc_filepaths]

    index_is_empty = faiss_index is None or faiss_index.ntotal == 0
    if index_is_empty:
        # A fresh index numbers vectors from 0, so stale map entries must go.
        doc_filepaths.clear()

    known_sources = set(doc_filepaths)
    new_documents = []
    new_filepaths = []
    skipped_duplicates = 0

    for result in web_results:
        url = result.get('url', 'N/A')
        content = result.get('content', '')
        if not content:
            continue
        identifier = f"Web Source: {url}"
        if url != 'N/A' and identifier in known_sources:
            skipped_duplicates += 1
            continue
        known_sources.add(identifier)
        new_documents.append(content)
        new_filepaths.append(identifier)

    if not new_documents:
        if skipped_duplicates:
            print("\nAll web results are already in the knowledge base.\n")
        else:
            print("\nNo valid content found in web results to add.\n")
        return

    print(f"Adding {len(new_documents)} web documents to the knowledge base...")

    # FAISS expects contiguous float32 vectors.
    new_embeddings = np.ascontiguousarray(
        embedding_model.encode(new_documents, convert_to_tensor=False), dtype=np.float32
    )

    if index_is_empty:
        embedding_dim = new_embeddings.shape[1]
        flat_index = faiss.IndexFlatL2(embedding_dim)
        faiss_index = faiss.IndexIDMap(flat_index)
        current_id = 0
        print("Initialized a new FAISS index.")
    else:
        current_id = faiss_index.ntotal

    # Explicit int64: the default integer dtype is platform-dependent.
    new_ids = np.arange(current_id, current_id + len(new_embeddings), dtype=np.int64)

    faiss_index.add_with_ids(new_embeddings, new_ids)

    doc_filepaths.extend(new_filepaths)

    try:
        faiss.write_index(faiss_index, "medical_kb.index")
        with open("doc_map.txt", "w", encoding='utf-8') as f:
            for i, filepath in enumerate(doc_filepaths):
                f.write(f"{i}: {filepath}\n")
        print(f"\nFAISS index updated with {len(new_documents)} web documents. Total vectors: {faiss_index.ntotal}\n")
    except Exception as e:
        print(f"Error saving updated FAISS index or doc map: {e}")


def generate_explanation(claim, classification):
    """Ask the LLM to explain a classified claim using knowledge-base context.

    The context comes from retrieve_from_faiss(claim). Returns the model's
    reply, or an "Error during generation: ..." string if the call fails.
    """
    print(f"\nGenerating explanation for claim: '{claim}' classified as '{classification}'")

    kb_context = retrieve_from_faiss(claim)

    context_section = f"""
    CONTEXT FROM KNOWLEDGE BASE:
    {kb_context}
    """

    prompt = f"""
    You are a precise medical fact-checking assistant.
    A medical claim has been classified as: **{classification}**.
    The claim is: **"{claim}"**

    Using the provided context from the knowledge base below, generate a concise and medically accurate explanation.
    - Directly address the claim's validity based on the classification.
    - Synthesize information from the knowledge base.
    - If citing a web source, refer to the 'Web Source: [URL]' identifier provided in the context.
    - Maintain a neutral, factual, and easy-to-understand tone.
    - Do not use any markdown formatting in your response.

    {context_section}

    Final Explanation:
    """

    messages = [{"role": "user", "content": prompt}]
    try:
        # Low temperature keeps the explanation close to deterministic.
        completion = groq_client.chat.completions.create(
            messages=messages,
            model="llama3-70b-8192",
            temperature=0.2,
        )
        answer = completion.choices[0].message.content
    except Exception as e:
        answer = f"Error during generation: {e}"
    return answer


# End-to-end run: classify the claim, add fresh web results to the knowledge
# base, then generate the explanation.
input_claim ="BRAIN EATING BACTERIA PRESENT IN STILL WATER"

claim_classification = classify_claim(input_claim)
print(f"\nClassified claim as: {claim_classification}\n")

web_results = retrieve_from_web(input_claim)
add_web_results_to_kb(web_results)

final_explanation = generate_explanation(input_claim, claim_classification)

print(
    "\n" + "-"*50,
    "FINAL GENERATED EXPLANATION",
    "-"*50,
    final_explanation,
    sep="\n",
)

Initializing components...
Existing FAISS index and document map loaded.
Loading fine-tuned model from ./fine-tuned-mythbert...
Fine-tuned model and tokenizer loaded successfully.

Classified claim as: Myth

Adding 3 web documents to the knowledge base...

FAISS index updated with 3 web documents. Total vectors: 11


Generating explanation for claim: 'BRAIN EATING BACTERIA PRESENT IN STILL WATER' classified as 'Myth'

--------------------------------------------------
FINAL GENERATED EXPLANATION
--------------------------------------------------
The claim "BRAIN EATING BACTERIA PRESENT IN STILL WATER" is a myth. 

The myth likely originated from the fact that a deadly brain-eating amoeba called Naegleria fowleri can be found in freshwater lakes, rivers, and hot springs. However, it is not a bacterium, but rather an amoeba. According to the Cleveland Clinic, Naegleria fowleri is a type of free-living amoeba that can be found in warm, freshwater environments. 

It's also important to not

In [None]:
import pandas as pd

# (parameter name, value) pairs in display order; values are read from the
# objects created in the training cell.
_hyperparameter_rows = [
    ("Model Name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"),
    ("Num Labels", len(labels)),  # Assuming 'labels' is defined in a previous cell
    ("Max Token Length", 128),
    ("Num Train Epochs", training_args.num_train_epochs),
    ("Per Device Train Batch Size", training_args.per_device_train_batch_size),
    ("Per Device Eval Batch Size", training_args.per_device_eval_batch_size),
    ("Warmup Steps", training_args.warmup_steps),
    ("Weight Decay", training_args.weight_decay),
    ("Logging Steps", training_args.logging_steps),
    ("Evaluation Strategy", training_args.eval_strategy),
    ("Save Strategy", training_args.save_strategy),
    ("Load Best Model At End", training_args.load_best_model_at_end),
]

hyperparameters = {
    "Parameter": [row[0] for row in _hyperparameter_rows],
    "Value": [row[1] for row in _hyperparameter_rows],
}

hyperparameter_table = pd.DataFrame(hyperparameters)
print("Hyperparameters Used:")
print(hyperparameter_table.to_string(index=False))

Hyperparameters Used:
                  Parameter                                                Value
                 Model Name microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract
                 Num Labels                                                    3
           Max Token Length                                                  128
           Num Train Epochs                                                    6
Per Device Train Batch Size                                                    8
 Per Device Eval Batch Size                                                    8
               Warmup Steps                                                  500
               Weight Decay                                                 0.01
              Logging Steps                                                   10
        Evaluation Strategy                               IntervalStrategy.EPOCH
              Save Strategy                                   SaveStrategy.EPOCH
     L