# PubMed Medical Literature Analysis with Gradio

<table style="float: left; margin-right: 20px;">
  <tr>
    <td style="text-align: center">
      <a href="https://colab.research.google.com/github/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
        <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Run in Colab
      </a>
    </td>
    <td style="text-align: center">
      <a href="https://github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
        <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
      </a>
    </td>
  </tr>
</table>

<div style="clear: both;"></div>

## 🚀 Quick Start

1. Click **Runtime → Run all** (or press Ctrl/Cmd + F9)
2. Wait for the app to launch (~30 seconds)
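3. Click the Gradio link that appears, then open the **1. Setup** tab and click **Validate & Setup Project** — the **Run Full Analysis** button stays disabled until setup succeeds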

This notebook analyzes medical cases using PubMed literature with BigQuery vector search and Gemini AI.

In [None]:
# @title 1️⃣ Install Dependencies { display-mode: "form" }
!pip install gradio google-genai google-cloud-bigquery pandas plotly -q
print("✅ Dependencies installed")

In [None]:
# @title 2️⃣ Authenticate with Google Cloud { display-mode: "form" }
import sys

# The presence of the `google.colab` module in sys.modules is used as the
# signal that this notebook is running inside Colab.
if "google.colab" not in sys.modules:
    # Outside Colab no interactive login is attempted.
    print("ℹ️ Running locally - using default credentials")
else:
    from google.colab import auth

    auth.authenticate_user()
    print("✅ Authenticated with Google Cloud")

In [None]:
# @title 3️⃣ Configuration { display-mode: "form" }
import os

# User configuration (the `# @param` annotations render as Colab form fields,
# so each must stay on the same line as its assignment).
PROJECT_ID = "" # @param {type:"string"}
# When the form field is left blank, fall back to the GOOGLE_CLOUD_PROJECT
# environment variable (which may itself be unset, leaving an empty string).
if not PROJECT_ID:
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "")

LOCATION = "us-central1" # @param ["us-central1", "us-east1", "us-west1", "europe-west1", "asia-northeast1"] {type:"string"}
USER_DATASET = "pubmed" # @param {type:"string"}

# Constants
# Fully-qualified dataset / table holding the PubMed embeddings that are searched.
PUBMED_DATASET = "wz-data-catalog-demo.pubmed"
PUBMED_TABLE = f"{PUBMED_DATASET}.pmid_embed_nonzero_metadata"
# Path of the remote embedding model inside the user's own dataset. This value
# is fixed when the cell runs; run_analysis rebuilds the same path from the
# current PROJECT_ID / USER_DATASET globals instead of reading this constant.
EMBEDDING_MODEL = f"{PROJECT_ID}.{USER_DATASET}.textembed"
# Gemini model name passed to every generate_content call.
MODEL_ID = "gemini-2.5-flash"
# CSV that load_journal_data reads (semicolon-separated, 'Title' and 'SJR' columns).
JOURNAL_IMPACT_CSV_URL = "https://raw.githubusercontent.com/WandLZhang/scimagojr_2024/main/scimagojr_2024.csv"

# Sample case for demo; used as the initial value of the case-notes textbox.
SAMPLE_CASE = """A 4-year-old male presents with a 3-week history of progressive fatigue, pallor, and easy bruising. \nPhysical examination reveals hepatosplenomegaly and scattered petechiae. \n\nLaboratory findings:\n- WBC: 45,000/μL with 80% blasts\n- Hemoglobin: 7.2 g/dL\n- Platelets: 32,000/μL\n\nFlow cytometry: CD33+, CD13+, CD117+, CD34+, HLA-DR+, CD19-, CD3-\n\nCytogenetics: 46,XY,t(9;11)(p21.3;q23.3)\nMolecular: KMT2A-MLLT3 fusion detected, FLT3-ITD positive, NRAS G12D mutation\n\nDiagnosis: KMT2A-rearranged acute myeloid leukemia (AML)"""

# Echo the effective configuration so a blank project ID is visible immediately.
print(f"📍 Project: {PROJECT_ID}")
print(f"📍 Location: {LOCATION}")
print(f"📍 Dataset: {USER_DATASET}")

In [None]:
# @title 4️⃣ Launch PubMed Analysis App { display-mode: "form" }

import gradio as gr
import pandas as pd
import json
import math
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
from google import genai
from google.cloud import bigquery
from google.genai.types import GenerateContentConfig

# --- Shared application state ---
# Both clients start unset; validate_project_and_setup assigns them, and
# run_analysis refuses to run while either is still None.
genai_client = None
bq_client = None
# Journal title -> SJR value, filled in by validate_project_and_setup.
journal_impact_dict = {}

# --- Core Functions ---
def init_clients(project_id, location):
    """Create the Gemini (Vertex AI) and BigQuery clients for a project.

    Args:
        project_id: Google Cloud project ID used for both clients.
        location: Region passed to the GenAI client.

    Returns:
        A ``(genai_client, bq_client)`` tuple, or ``(None, None)`` when either
        client could not be constructed.
    """
    try:
        new_genai_client = genai.Client(vertexai=True, project=project_id, location=location)
        new_bq_client = bigquery.Client(project=project_id)
    except Exception as e:
        # Deliberately best-effort: callers test for None. Print the cause so
        # the failure can be diagnosed from the cell output instead of being
        # swallowed silently.
        print(f"⚠️ Failed to initialize clients: {e}")
        return None, None
    return new_genai_client, new_bq_client

def load_journal_data():
    """Load the journal impact table as a ``{title: SJR}`` dictionary.

    Reads the semicolon-separated CSV at ``JOURNAL_IMPACT_CSV_URL`` and maps
    its 'Title' column to its 'SJR' column.

    Returns:
        The title -> SJR mapping, or an empty dict if the CSV could not be
        downloaded or parsed, or lacks the expected columns.
    """
    try:
        df = pd.read_csv(JOURNAL_IMPACT_CSV_URL, sep=';')
        return dict(zip(df['Title'], df['SJR']))
    except Exception as e:
        # Journal data is optional, so fall back to an empty mapping, but
        # report why. `Exception` (not a bare except) lets KeyboardInterrupt
        # and SystemExit propagate.
        print(f"⚠️ Could not load journal impact data: {e}")
        return {}

def extract_medical_info(case_text, client):
    """Ask the model for the primary disease and actionable events in a case.

    One request is sent per extraction prompt, each with temperature 0.

    Args:
        case_text: Free-text patient case notes appended to every prompt.
        client: GenAI client exposing ``models.generate_content``.

    Returns:
        A dict with keys ``"disease"`` and ``"events"`` holding the stripped
        model output; a value is ``""`` when the model returned no text.
    """
    prompts = {
        "disease": "Extract the primary disease diagnosis. Return ONLY the name.",
        "events": "Extract all actionable medical events (mutations, biomarkers, etc.). Return a comma-separated list."
    }
    # The generation config is identical for every prompt, so build it once.
    config = GenerateContentConfig(temperature=0)
    results = {}
    for key, prompt in prompts.items():
        full_prompt = f"{prompt}\n\nCase notes:\n{case_text}"
        response = client.models.generate_content(model=MODEL_ID, contents=[full_prompt], config=config)
        # `response.text` is None when the response carries no text parts
        # (e.g. a blocked or empty candidate); guard before calling .strip().
        results[key] = (response.text or "").strip()
    return results

def search_pubmed_articles(disease, events, bq_client, embedding_model, pubmed_table, top_k):
    """Run a BigQuery vector search for articles matching a disease and events.

    The disease name and events are joined into one query string, embedded
    with ``embedding_model`` and matched against the
    ``ml_generate_embedding_result`` column of ``pubmed_table``.

    Args:
        disease: Disease name (model-extracted free text).
        events: Iterable of event strings (model-extracted free text).
        bq_client: BigQuery client used to run the query.
        embedding_model: Fully-qualified BigQuery ML embedding model path.
        pubmed_table: Fully-qualified table to search.
        top_k: Number of nearest neighbours to return; coerced to ``int``.

    Returns:
        A DataFrame with columns PMID, content, abstract and distance.
    """
    query_text = f"{disease} {' '.join(events)}"
    # The search text comes from model output, so it is passed as a query
    # parameter rather than spliced into the SQL: a quote or backslash in the
    # text would otherwise break (or inject into) the statement. Table and
    # model identifiers cannot be parameterized and remain interpolated.
    sql = f"""SELECT base.PMID, base.content, base.text as abstract, distance FROM VECTOR_SEARCH(TABLE `{pubmed_table}`, 'ml_generate_embedding_result', (SELECT ml_generate_embedding_result FROM ML.GENERATE_EMBEDDING(MODEL `{embedding_model}`, (SELECT @query_text AS content))), top_k => {int(top_k)})"""
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("query_text", "STRING", query_text)]
    )
    return bq_client.query(sql, job_config=job_config).to_dataframe()

def analyze_article_batch(df, disease, events, client, journal_dict):
    """Ask the model to extract structured metadata for a batch of articles.

    Args:
        df: DataFrame of articles; each row must have 'PMID' and should have
            'content' (falling back to 'abstract', then to an empty string).
        disease: Disease name inserted into the prompt.
        events: Iterable of event strings inserted into the prompt.
        client: GenAI client exposing ``models.generate_content``.
        journal_dict: Mapping of journal title -> SJR; every entry is listed
            in the prompt.

    Returns:
        The parsed JSON array as a list, or ``[]`` when the response is empty,
        is not valid JSON, or is valid JSON that is not an array.
    """
    # NOTE(review): the whole journal table is embedded in the prompt; for a
    # large table this dominates the request size.
    journal_context = "\n".join(f"- {title}: {sjr}" for title, sjr in journal_dict.items())
    prompt = f"""Analyze articles for relevance to Disease: {disease} and Events: {', '.join(events)}. Use this data: {journal_context}. For each, extract: title, journal_title, journal_sjr, year, disease_match (bool), pediatric_focus (bool), treatment_shown (bool), paper_type, key_findings, clinical_trial (bool), novel_findings (bool). Return JSON array."""
    # Collect the per-article sections and join once instead of growing a
    # string inside the loop.
    sections = []
    for _, row in df.iterrows():
        content = row.get('content', row.get('abstract', ''))
        sections.append(f"\n---\nPMID: {row['PMID']}\nContent: {content}\n")
    articles_text = "".join(sections)
    response = client.models.generate_content(model=MODEL_ID, contents=[prompt + articles_text], config=GenerateContentConfig(temperature=0, response_mime_type="application/json"))
    try:
        # `response.text` may be None; json.loads(None) would raise TypeError,
        # which the handler below does not cover, so substitute "".
        parsed = json.loads(response.text or "")
    except json.JSONDecodeError:
        return []
    # Callers iterate over per-article dicts, so anything but a list is unusable.
    return parsed if isinstance(parsed, list) else []

def _flag_is_set(value):
    """Return True only when *value* is a genuinely affirmative flag."""
    if value is None:
        return False
    # NaN is truthy in Python, and is what a DataFrame row holds for an
    # article that received no analysis; it must not count as a match.
    if isinstance(value, float) and math.isnan(value):
        return False
    # Tolerate flags that arrive as text, where bool("false") would be True.
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def calculate_article_score(metadata, config):
    """Score an article from its boolean relevance flags.

    Args:
        metadata: Mapping-like object supporting ``.get`` (a dict or a
            DataFrame row) with optional 'disease_match', 'treatment_shown'
            and 'clinical_trial' flags.
        config: Mapping of weights. Keys 'disease_match' (default 50),
            'treatment_efficacy' (default 50) and 'clinical_trial'
            (default 40) are read.

    Returns:
        The sum of the weights of all set flags, rounded to 2 decimals.
        Missing, None and NaN flags count as not set.
    """
    score = 0
    # Simplified scoring logic for brevity
    if _flag_is_set(metadata.get('disease_match')):
        score += config.get('disease_match', 50)
    if _flag_is_set(metadata.get('treatment_shown')):
        score += config.get('treatment_efficacy', 50)
    if _flag_is_set(metadata.get('clinical_trial')):
        score += config.get('clinical_trial', 40)
    return round(score, 2)

def setup_bigquery(project, dataset, location, client):
    """Ensure the user's dataset and remote embedding model exist.

    Args:
        project: Google Cloud project ID.
        dataset: Dataset name inside ``project``.
        location: Accepted for interface compatibility; not used by this
            function (the dataset is created without an explicit location).
        client: BigQuery client used for all calls.

    Returns:
        A success message naming ``project.dataset``.

    Raises:
        Whatever the client raises if dataset creation or the
        ``CREATE MODEL`` statement fails.
    """
    dataset_ref = f"{project}.{dataset}"
    try:
        client.get_dataset(dataset_ref)
    except Exception:
        # Any lookup failure is treated as "dataset missing" and answered with
        # an idempotent create. `Exception` rather than a bare except keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        client.create_dataset(bigquery.Dataset(dataset_ref), exists_ok=True)
    model_query = f"CREATE MODEL IF NOT EXISTS `{project}.{dataset}.textembed` REMOTE WITH CONNECTION DEFAULT OPTIONS(endpoint='text-embedding-005');"
    # .result() blocks until the statement finishes and surfaces any error.
    client.query(model_query).result()
    return f"✅ BigQuery setup complete for {project}.{dataset}"

# --- Gradio App Logic ---
def validate_project_and_setup(project_id, dataset_name):
    """Validate the Setup-tab inputs and prepare clients, journals and BigQuery.

    On success the module globals ``genai_client``, ``bq_client``,
    ``journal_impact_dict``, ``PROJECT_ID`` and ``USER_DATASET`` are updated.

    Args:
        project_id: Google Cloud project ID typed by the user.
        dataset_name: BigQuery dataset name typed by the user.

    Returns:
        A ``(status_message, button_update)`` pair: the status text for the
        setup panel and a ``gr.update`` that enables the analysis button only
        when every step succeeded.
    """
    global genai_client, bq_client, journal_impact_dict, PROJECT_ID, USER_DATASET
    # Pasted values often carry stray whitespace, which would otherwise end up
    # inside the fully-qualified BigQuery names built from them.
    project_id = (project_id or "").strip()
    dataset_name = (dataset_name or "").strip()
    if not project_id or not dataset_name:
        return "❌ Project ID and Dataset Name are required.", gr.update(interactive=False)
    PROJECT_ID = project_id
    USER_DATASET = dataset_name
    genai_client, bq_client = init_clients(project_id, LOCATION)
    if not genai_client or not bq_client:
        return "❌ Failed to initialize clients. Check permissions.", gr.update(interactive=False)
    journal_impact_dict = load_journal_data()
    try:
        setup_message = setup_bigquery(project_id, dataset_name, LOCATION, bq_client)
    except Exception as e:
        # Report setup failures (permissions, missing connection, ...) in the
        # same status style as the other checks and keep analysis disabled.
        return f"❌ BigQuery setup failed: {e}", gr.update(interactive=False)
    return f"✅ Project validated. Loaded {len(journal_impact_dict)} journals.\n{setup_message}", gr.update(interactive=True)

def run_analysis(case_text, num_articles, progress=gr.Progress()):
    """Run the full pipeline: extract, search, analyze, score and rank.

    Args:
        case_text: Patient case notes.
        num_articles: Number of articles to retrieve; coerced to ``int``.
        progress: Gradio progress tracker.

    Returns:
        A ``(results_table, status_message, results)`` triple. The table
        holds up to 10 rows with columns score, title, journal_title and year;
        ``results`` carries every article record plus the extracted disease,
        events and the original case text. When a precondition fails or the
        search returns no rows the triple is ``(None, message, {})``.
    """
    if not genai_client or not bq_client:
        return None, "❌ Please validate project first.", {}
    if not case_text or not case_text.strip():
        return None, "❌ Please enter patient case notes.", {}

    progress(0.1, desc="Extracting medical info...")
    medical_info = extract_medical_info(case_text, genai_client)
    disease = medical_info.get('disease', '')
    # Drop blank entries so an empty model answer yields [] rather than [''].
    events = [e.strip() for e in medical_info.get('events', '').split(',') if e.strip()]

    progress(0.3, desc="Searching PubMed...")
    embedding_model_path = f"{PROJECT_ID}.{USER_DATASET}.textembed"
    articles_df = search_pubmed_articles(disease, events, bq_client, embedding_model_path, PUBMED_TABLE, int(num_articles))
    if articles_df.empty:
        return None, f"⚠️ No articles found for '{disease}'.", {}

    progress(0.6, desc="Analyzing articles...")
    analyses = analyze_article_batch(articles_df, disease, events, genai_client, journal_impact_dict)

    # Merge the analyses into the article records by position (the model is
    # not asked to echo PMIDs). Working on plain dicts avoids two problems of
    # cell-wise `.loc` assignment: list-valued fields failing to assign, and
    # surplus model items silently appending new rows. zip() stops at the
    # shorter sequence, and non-dict items are skipped.
    records = articles_df.to_dict('records')
    for record, analysis in zip(records, analyses):
        if isinstance(analysis, dict):
            record.update(analysis)

    # SCORING_PRESETS is not defined in the visible notebook code; look it up
    # defensively so that, when absent, calculate_article_score's built-in
    # default weights are used instead of raising NameError.
    scoring_config = globals().get("SCORING_PRESETS", {}).get("Clinical Focus", {})
    # Score from the dict records: a flag that is missing reads as None rather
    # than as a truthy NaN from a DataFrame row.
    scores = [calculate_article_score(record, scoring_config) for record in records]
    articles_df = pd.DataFrame(records)
    articles_df['score'] = scores
    # drop=True keeps the old positional index from leaking into the results.
    articles_df = articles_df.sort_values('score', ascending=False).reset_index(drop=True)

    progress(0.9, desc="Generating results...")
    # reindex() tolerates analysis columns that are absent (e.g. the model
    # returned nothing usable) by filling them with NaN instead of raising.
    results_table = articles_df.reindex(columns=['score', 'title', 'journal_title', 'year']).head(10)
    results = {'articles': articles_df.to_dict('records'), 'disease': disease, 'events': events, 'case_text': case_text}
    return results_table, f"✅ Analysis complete for '{disease}'.", results

# --- Gradio UI ---
# Three tabs: project setup, case analysis, and the ranked results table.
# Components appear in the order they are created inside each context.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 PubMed Medical Literature Analysis")
    # Holds the full results dict returned by run_analysis.
    app_state = gr.State({})

    with gr.Tabs():
        with gr.TabItem("1. Setup"):
            # Pre-filled from the configuration cell; editable before validation.
            project_id_input = gr.Textbox(label="Google Cloud Project ID", value=PROJECT_ID)
            dataset_input = gr.Textbox(label="BigQuery Dataset Name", value=USER_DATASET)
            setup_btn = gr.Button("Validate & Setup Project", variant="primary")
            setup_status = gr.Markdown()

        with gr.TabItem("2. Analyze Case"):
            case_input = gr.Textbox(label="Patient Case Notes", value=SAMPLE_CASE, lines=10)
            # Range 5-50, default 10, whole-number steps.
            num_articles_slider = gr.Slider(5, 50, 10, step=1, label="Number of Articles to Analyze")
            # Starts disabled; validate_project_and_setup enables it on success.
            analyze_btn = gr.Button("Run Full Analysis", variant="primary", interactive=False)
            analysis_status = gr.Markdown()

        with gr.TabItem("3. Results"):
            results_df = gr.DataFrame(label="Top 10 Ranked Articles")

    # Setup writes its status message and toggles the analysis button.
    setup_btn.click(validate_project_and_setup, inputs=[project_id_input, dataset_input], outputs=[setup_status, analyze_btn])
    # Analysis fills the results table, the status line and the stored state.
    analyze_btn.click(run_analysis, inputs=[case_input, num_articles_slider], outputs=[results_df, analysis_status, app_state])

# NOTE(review): share=True asks Gradio for a public share link — confirm that
# exposure is intended, since the app acts with the notebook user's credentials.
demo.launch(share=True, debug=True)