<a href="https://colab.research.google.com/github/TAUforPython/BioMedAI/blob/main/LLM%20Mistral%20Medical%20Claims%20Dataset%20generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install mistralai --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/440.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m440.5/440.5 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/160.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m160.3/160.3 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import json
import random
import time
from mistralai import Mistral # Import Mistral client
from tenacity import retry, stop_after_attempt, wait_exponential

In [4]:
# Read the Mistral API key. On Colab it comes from the "Mistral_API" secret; when not
# on Colab (or the secret cannot be read) fall back to the MISTRAL_API_KEY env var.
MISTRAL_TOKEN = None
try:
    from google.colab import userdata
    MISTRAL_TOKEN = userdata.get("Mistral_API")
except ImportError:
    pass  # not running on Colab; rely on the environment variable below
except Exception as exc:
    # NOTE(review): Colab's userdata.get raises (e.g. SecretNotFoundError /
    # NotebookAccessError) instead of returning None — confirm exception types.
    print(f"⚠️  Could not read Colab secret 'Mistral_API': {exc}")
MISTRAL_TOKEN = MISTRAL_TOKEN or os.environ.get("MISTRAL_API_KEY")

if MISTRAL_TOKEN:
    print("✅ MISTRAL token detected")
else:
    print("⚠️  No MISTRAL TOKEN")

✅ MISTRAL token detected


In [24]:

client = Mistral(api_key=MISTRAL_TOKEN) # Initialize Mistral client

# --- Generation settings ---
MODEL_NAME = "mistral-large-latest" # Or "open-mistral-7b", "open-mixtral-8x7b", etc.
NUM_SUPPORTED_CLAIMS = 300  # target number of LLM-generated Supported pairs
NUM_REFUTED_CLAIMS = 300    # target number of Refuted pairs (drawn from base_refuted_claims_examples)
MAX_RETRIES = 3             # attempts per LLM call (passed to tenacity's stop_after_attempt)
REQUEST_DELAY = 0.1         # seconds to pause during generation to throttle API calls
RANDOM_SEED = 42            # seeds `random`, which picks facts and refuted claims

# Makes claim selection repeatable; LLM output itself stays stochastic (temperature 0.7).
random.seed(RANDOM_SEED)

# --- Base Prompts ---
# Prompt for generate_context(): sent verbatim (it has no placeholders) to obtain a
# short medical paragraph that later serves as evidence for Supported claims.
CONTEXT_PROMPT_TEMPLATE = """
Generate a short, realistic paragraph (3-5 sentences) about a medical topic.
The paragraph should contain factual information that could be found in a medical journal or reputable health source.
Focus on topics like diet, exercise, disease prevention, treatment efficacy, or physiological effects.
Example Topics: Benefits of Omega-3s, Risks of Chronic Stress, Effectiveness of Vaccines, Impact of Sleep Deprivation.

Medical Context Paragraph:
"""

# Prompt for extract_facts(): filled with str.format(text=...), so the literal JSON
# braces in the example are doubled ({{ }}) to survive formatting. extract_facts()
# looks for a {"facts": [...]} object in the reply. Note the "..." in the example
# is not itself valid JSON; the model is expected to replace it with real entries.
FACT_EXTRACTION_PROMPT_TEMPLATE = """
Segment the following medical text into individual, atomic facts. An atomic fact is a simple, self-contained statement of medical information.

Text: {text}

Please respond **only** with a JSON object in the following format:
{{
  "facts": [
    "Fact 1",
    "Fact 2",
    "Fact 3",
    ...
  ]
}}

JSON Facts:
"""

# --- LLM Interaction Functions (with Retries) ---

@retry(
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,  # surface the real API error instead of tenacity.RetryError
)
def call_llm(prompt, model=MODEL_NAME, max_tokens=500, temperature=0.7):
    """Send `prompt` to the Mistral chat API and return the reply text.

    Retried up to MAX_RETRIES times with exponential back-off (4-10 s) on any
    exception; after the last attempt the original exception is re-raised.

    Args:
        prompt: User message content.
        model: Mistral model name (default bound to MODEL_NAME at definition time).
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The reply with surrounding whitespace stripped, or "" when the API
        returned no content.
    """
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant providing accurate medical information."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = chat_response.choices[0].message.content
    if content is None:
        # Previously crashed with AttributeError on .strip() (and was pointlessly retried).
        return ""
    if not isinstance(content, str):
        # NOTE(review): the SDK may return a list of content chunks; keep only their text parts.
        content = "".join(getattr(chunk, "text", "") or "" for chunk in content)
    return content.strip()

# --- Synthetic Data Generation Functions ---

def generate_context():
    """Ask the LLM for one short medical paragraph and return its text.

    CONTEXT_PROMPT_TEMPLATE has no placeholders, so it is sent unchanged.
    """
    return call_llm(CONTEXT_PROMPT_TEMPLATE)

def extract_facts(text):
    """Split `text` into atomic facts using the LLM.

    Args:
        text: Medical paragraph, inserted into FACT_EXTRACTION_PROMPT_TEMPLATE.

    Returns:
        List of non-empty, stripped fact strings. Returns [] (after printing the
        raw reply) when the reply is not valid JSON or holds no list of facts.
        Exceptions raised by call_llm itself propagate to the caller.
    """
    prompt = FACT_EXTRACTION_PROMPT_TEMPLATE.format(text=text)
    response_text = call_llm(prompt)

    # The model may wrap the JSON in prose or ```json fences, so parse the outermost
    # {...} span when there is one, otherwise try the whole reply.
    start_index = response_text.find('{')
    end_index = response_text.rfind('}') + 1
    if 0 <= start_index < end_index:
        json_str = response_text[start_index:end_index]
    else:
        json_str = response_text

    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"Error parsing facts from LLM response: {e}. Response was: {response_text}")
        return []  # Return empty list on failure

    # Accept the requested {"facts": [...]} shape, or a bare JSON list of facts.
    facts = data.get("facts", []) if isinstance(data, dict) else data
    if not isinstance(facts, list):
        print(f"Unexpected facts format in LLM response: {response_text}")
        return []
    # Skip non-string entries (e.g. nested objects) rather than crashing on .strip().
    stripped = (fact.strip() for fact in facts if isinstance(fact, str))
    return [fact for fact in stripped if fact]

def generate_supported_pair():
    """Build one "Supported" example from a freshly generated medical paragraph.

    Returns:
        (claim, evidence, label): a randomly chosen atomic fact from the paragraph,
        the whole paragraph as its evidence, and label 1. Returns
        (None, None, None) when the paragraph is empty or yields no facts.
    """
    paragraph = generate_context()
    # Fact extraction is only attempted for a non-empty paragraph.
    extracted = extract_facts(paragraph) if paragraph else []
    if not extracted:
        return None, None, None

    # The chosen fact comes from the paragraph, so the paragraph supports it by construction.
    return random.choice(extracted), paragraph, 1

def generate_refuted_pair(base_refuted_claims):
    """Build one "Refuted" claim-evidence pair from a list of known-false claims.

    Args:
        base_refuted_claims: Sequence of claim strings; blank entries are ignored.

    Returns:
        (claim, evidence, label) where evidence is a templated sentence stating the
        claim is contradicted and label is 0, or (None, None, None) when no usable
        claim is available.
    """
    candidates = [c for c in (base_refuted_claims or []) if c and c.strip()]
    if not candidates:
        return None, None, None

    claim = random.choice(candidates)
    # "Vitamin C can cure the common cold." -> "vitamin C can cure the common cold":
    # drop only trailing periods (not an arbitrary last character) and lower-case just
    # the first letter, leaving acronyms such as "MSG" or "pH" intact.
    statement = claim.strip().rstrip(".")
    if not (len(statement) > 1 and statement[1].isupper()):
        statement = statement[:1].lower() + statement[1:]
    # Simple simulated evidence for refuted claims.
    evidence = (
        "Epidemiological research and systematic reviews show that the claim that "
        f"{statement} is contradicted by current medical understanding."
    )
    return claim, evidence, 0  # 0 = Refuted

# --- Main Generation Loop ---

def _collect_pairs(make_pair, target, kind, seen, claims, evidences, labels,
                   max_attempts, delay=0.0):
    """Append up to `target` new (claim, evidence, label) triples produced by `make_pair`.

    Stops after `max_attempts` calls so a generator that keeps failing or keeps
    returning duplicates cannot loop forever. `seen` is updated with every accepted
    claim. Returns the number of pairs appended.
    """
    generated = 0
    attempts = 0
    while generated < target and attempts < max_attempts:
        attempts += 1
        try:
            claim, evidence, label = make_pair()
        except Exception as exc:
            # e.g. an API error after all retries: count a failed attempt but keep
            # everything generated so far instead of aborting the whole run.
            print(f"  Error while generating a {kind} pair (attempt {attempts}): {exc}")
            claim, evidence, label = None, None, None

        if not (claim and evidence and label is not None):
            print(f"  Failed to generate a valid {kind} pair (attempt {attempts}).")
        elif claim in seen:
            print(f"  Duplicate claim skipped ({kind}).")
        else:
            seen.add(claim)
            claims.append(claim)
            evidences.append(evidence)
            labels.append(label)
            generated += 1
            print(f"  Generated {kind} Pair {generated}/{target}")

        if delay:
            time.sleep(delay)  # throttle on every attempt, since failures also hit the API

    if generated < target:
        print(f"  Warning: generated only {generated}/{target} {kind} pairs "
              f"after {attempts} attempts.")
    return generated


def generate_synthetic_dataset(num_supported, num_refuted, base_refuted_claims):
    """Generate a labelled claim/evidence dataset (1 = Supported, 0 = Refuted).

    Supported pairs come from generate_supported_pair() (LLM calls, capped at
    3 * num_supported attempts). Refuted pairs come from generate_refuted_pair()
    over the distinct, unused claims in `base_refuted_claims`, so at most that many
    refuted pairs are produced. Claims are unique across the whole dataset.

    Returns:
        (claims, evidences, labels): three parallel lists.
    """
    synthetic_claims = []
    synthetic_evidences = []
    synthetic_labels = []
    seen_claims = set()  # O(1) duplicate checks across both labels

    print(f"Generating {num_supported} Supported claim-evidence pairs...")
    _collect_pairs(
        generate_supported_pair, num_supported, "Supported", seen_claims,
        synthetic_claims, synthetic_evidences, synthetic_labels,
        max_attempts=num_supported * 3, delay=REQUEST_DELAY,
    )

    print(f"Generating {num_refuted} Refuted claim-evidence pairs...")
    # Refuted claims come from a finite list and must be unique, so visit each distinct
    # unused claim once in random order; random draws with replacement would spin
    # forever once the list is exhausted (e.g. 300 requested from 20 base claims).
    unused = [c for c in dict.fromkeys(base_refuted_claims or [])
              if c and c.strip() and c not in seen_claims]
    remaining = iter(random.sample(unused, len(unused)))
    target = min(num_refuted, len(unused))
    if target < num_refuted:
        print(f"  Only {len(unused)} distinct refuted base claims available; "
              f"generating {target} instead of {num_refuted}.")
    _collect_pairs(
        lambda: generate_refuted_pair([next(remaining)]), target, "Refuted",
        seen_claims, synthetic_claims, synthetic_evidences, synthetic_labels,
        max_attempts=len(unused),
    )

    return synthetic_claims, synthetic_evidences, synthetic_labels

# --- Base Refuted Examples ---
# Seed claims that generate_refuted_pair() samples from. Each is treated as false
# (label 0) and paired with a templated "contradicted by current medical
# understanding" evidence sentence.
# NOTE: generate_synthetic_dataset() rejects duplicate claims, so at most
# len(base_refuted_claims_examples) (20) distinct refuted pairs can come from this
# list, far fewer than NUM_REFUTED_CLAIMS. Extend the list for a larger refuted set.
base_refuted_claims_examples = [
    "Vitamin C can cure the common cold.",
    "Antibiotics treat viral infections.",
    "Probiotics cure serious bacterial infections.",
    "Eating carrots significantly improves night vision.",
    "Vaccines cause autism.",
    "Drinking exactly 8 glasses of water daily is required.",
    "MSG causes headaches in everyone.",
    "Microwave cooking destroys all nutrients.",
    "Red wine consumption is always heart-healthy.",
    "Artificial sweeteners cause cancer.",
    "Detox diets are necessary for body cleansing.",
    "Protein intake damages healthy kidneys.",
    "Crash diets are effective for long-term weight loss.",
    "Supplements replace a balanced diet.",
    "All fats are bad for health.",
    "Organic food always prevents disease.",
    "Homeopathy cures serious illnesses.",
    "Essential oils cure chronic conditions.",
    "Fasting cleanses toxins from organs.",
    "Alkaline water prevents cancer."
]


In [25]:

    # Run the full generation, then print the results as copy-pasteable Python list
    # literals (supported claims, refuted claims, evidence texts, labels).
    print("Starting synthetic medical fact-checking dataset generation using Mistral...")
    claims, evidences, labels = generate_synthetic_dataset(
        NUM_SUPPORTED_CLAIMS,
        NUM_REFUTED_CLAIMS,
        base_refuted_claims_examples
    )

    def _print_python_list(name, items):
        """Print `items` as a Python list literal assigned to `name`."""
        print(f"{name} = [")
        for item in items:
            # json.dumps escapes quotes, backslashes and newlines, so LLM text that
            # contains '"' (or '"""') still yields a valid Python literal; ints print as-is.
            print(f"    {json.dumps(item, ensure_ascii=False)},")
        print("]")

    # --- Output the results ---
    print("\n" + "="*50)
    print("GENERATED SYNTHETIC DATASET")
    print("="*50)

    supported_claims = [claim for claim, label in zip(claims, labels) if label == 1]
    refuted_claims = [claim for claim, label in zip(claims, labels) if label == 0]

    print("\n# Generated Supported Claims:")
    _print_python_list("base_claims_supported", supported_claims)

    print("\n# Generated Refuted Claims:")
    _print_python_list("base_claims_refuted", refuted_claims)

    print("\n# Generated Evidence Texts (corresponding to combined claims list):")
    _print_python_list("generated_evidence_texts", evidences)

    print("\n# Generated Labels (1=Supported, 0=Refuted):")
    _print_python_list("generated_labels", labels)

    print("\nDataset generation complete.")


Starting synthetic medical fact-checking dataset generation using Mistral...
Generating 300 Supported claim-evidence pairs...
  Generated Supported Pair 1/300
  Generated Supported Pair 2/300
  Generated Supported Pair 3/300
  Generated Supported Pair 4/300


KeyboardInterrupt: 

In [7]:
# Preview only: the full list can hold NUM_SUPPORTED_CLAIMS (300) entries.
supported_claims[:10]


['Progressive resistance exercise (PRE) improves functional capacity (e.g., gait speed, stair-climbing) in older adults within 12–24 weeks.',
 'Resistance training may provide cognitive benefits by increasing brain-derived neurotrophic factor (BDNF) levels.',
 'The muscle adaptations from resistance training are mediated by enhanced muscle protein synthesis.',
 'A 2023 meta-analysis in *The Journal of Gerontology* examined the effects of progressive resistance exercise (PRE) in older adults.',
 'Resistance training is effective in reducing fall risk in aging adults.']

In [8]:
# Preview only; the full list is printed by the generation cell.
refuted_claims[:10]

['Essential oils cure chronic conditions.',
 'Detox diets are necessary for body cleansing.',
 'MSG causes headaches in everyone.',
 'Drinking exactly 8 glasses of water daily is required.',
 'Artificial sweeteners cause cancer.']

In [15]:
import pandas

In [19]:
# Save supported claims with a named column and without pandas' RangeIndex, so the
# CSV reads back as a single "claim" column (not "Unnamed: 0" plus "0").
supported_claims_df = pandas.DataFrame({"claim": supported_claims})
supported_claims_df.to_csv('supported_claims.csv', index=False)

In [None]:
# Save refuted claims with a named column and without pandas' RangeIndex.
refuted_claims_df = pandas.DataFrame({"claim": refuted_claims})
refuted_claims_df.to_csv('refuted_claims.csv', index=False)

# Also persist the complete dataset: the per-label claim CSVs drop the evidence
# texts and labels returned by generate_synthetic_dataset.
dataset_df = pandas.DataFrame({"claim": claims, "evidence": evidences, "label": labels})
dataset_df.to_csv('synthetic_claims_dataset.csv', index=False)