In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PubMed Medical Literature Analysis

<!-- [PLACEHOLDER: Update these links when notebook is finalized] -->

<table style="float: left; margin-right: 20px;">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FWandLZhang%2Fpubmed-rag%2Fmain%2FPubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/WandLZhang/pubmed-rag/main/PubMed_RAG_Gradio.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<!-- [PLACEHOLDER: Update share links when notebook is finalized] -->
<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>            


| Authors |
| --- |
| [Willis Zhang](https://github.com/WandLZhang) |
| [Stone Jiang](https://github.com/siduojiang) |

## Overview

**Blog Post: Medical Literature Analysis with PubMed, BigQuery, Gemini**

<a href="[blog-post-url-placeholder]" target="_blank">
  <img src="https://storage.googleapis.com/[placeholder-image-path]/medical-literature-blog-header.jpg" alt="Medical Literature Analysis with PubMed and Gemini" width="500">
</a>

This notebook demonstrates how to analyze medical cases using PubMed literature with BigQuery vector search and Gemini. It converts the basic user experience from the [Capricorn Medical Research Application](https://capricorn-medical-research.web.app/) into an interactive Colab notebook.

## 🚀 Quick Start

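1. Click **Runtime → Run all** (or press Ctrl/Cmd + F9) and sign in with your Google account when prompted
2. Wait for the app to launch (~30 seconds)
3. Click the Gradio link that appears
4. In the **Setup** tab, select a project, click **Check Project Status**, then click **Complete Setup** before running an analysis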

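Under the hood, the app uses Gemini to extract the diagnosis and actionable medical events from the case notes, retrieves related PubMed articles with BigQuery vector search, and ranks the articles by a relevance score.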

In [None]:
# @title 1️⃣ Install Dependencies { display-mode: "form" }
!pip install gradio google-genai google-cloud-bigquery google-cloud-resourcemanager google-cloud-service-usage pandas plotly -q
print("✅ Dependencies installed")

In [None]:
# @title 2️⃣ Authenticate with Google Cloud { display-mode: "form" }
import sys

# Only trigger the interactive Colab sign-in when the google.colab package is
# already loaded in this interpreter; otherwise nothing is done here and the
# client libraries are left to pick up credentials on their own.
if "google.colab" not in sys.modules:
    print("ℹ️ Running locally - using default credentials")
else:
    from google.colab import auth

    auth.authenticate_user()
    print("✅ Authenticated with Google Cloud")

In [None]:
# @title 3️⃣ Configuration { display-mode: "form" }
import os

# Fixed constants for the notebook. The values a user chooses (project,
# location, dataset name) are collected in the Gradio "Setup" tab instead.

# BigQuery dataset and table holding the PubMed articles with their embeddings.
PUBMED_DATASET = "wz-data-catalog-demo.pubmed"
PUBMED_TABLE = f"{PUBMED_DATASET}.pmid_embed_nonzero_metadata"
# Gemini model used for every generate_content call in this notebook.
MODEL_ID = "gemini-2.5-flash"
# CSV with journal titles and SJR values; it is read with ';' as the separator.
JOURNAL_IMPACT_CSV_URL = "https://raw.githubusercontent.com/WandLZhang/scimagojr_2024/main/scimagojr_2024.csv"

# Available locations for Vertex AI (choices shown in the Setup tab dropdown)
VERTEX_AI_LOCATIONS = [
    "us-central1",
    "us-east1",
    "us-west1",
    "europe-west1",
    "asia-northeast1"
]

# Required APIs (service names checked, and optionally enabled, on the selected project)
REQUIRED_APIS = [
    "aiplatform.googleapis.com",
    "bigquery.googleapis.com",
    "cloudresourcemanager.googleapis.com"
]

# Sample case for demo (pre-filled into the "Patient Case Notes" textbox).
# NOTE(review): each line carries an explicit "\n" escape in addition to the
# literal line break, so the rendered text is double-spaced - confirm this is intended.
SAMPLE_CASE = """A 4-year-old male presents with a 3-week history of progressive fatigue, pallor, and easy bruising. \n
Physical examination reveals hepatosplenomegaly and scattered petechiae. \n
\n
Laboratory findings:\n
- WBC: 45,000/μL with 80% blasts\n
- Hemoglobin: 7.2 g/dL\n
- Platelets: 32,000/μL\n
\n
Flow cytometry: CD33+, CD13+, CD117+, CD34+, HLA-DR+, CD19-, CD3-\n
\n
Cytogenetics: 46,XY,t(9;11)(p21.3;q23.3)\n
Molecular: KMT2A-MLLT3 fusion detected, FLT3-ITD positive, NRAS G12D mutation\n
\n
Diagnosis: KMT2A-rearranged acute myeloid leukemia (AML)"""

print("✅ Constants loaded")

In [None]:
# @title 4️⃣ Launch PubMed Analysis App { display-mode: "form" }

import gradio as gr
import pandas as pd
import json
import math
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
from google import genai
from google.cloud import bigquery
from google.cloud import resourcemanager_v3
from google.cloud import service_usage_v1
from google.genai.types import GenerateContentConfig
import time

# --- Global Variables ---
# Runtime state; complete_setup() overwrites these once the user finishes the Setup tab.
genai_client = None
bq_client = None
journal_impact_dict = dict()
PROJECT_ID = ""
LOCATION = "us-central1"
USER_DATASET = "pubmed"

# Scoring presets: points added for each article flag that is set.
SCORING_PRESETS = {
    "Clinical Focus": dict(
        disease_match=50,
        treatment_efficacy=50,
        clinical_trial=40,
    ),
}

# --- Helper Functions for Enhanced Setup ---
def list_projects():
    """List the active Google Cloud projects visible to the current credentials.

    Returns:
        list[dict]: One dict per project in the ACTIVE state, with keys
        ``"id"`` (project ID), ``"name"`` (display name, falling back to the
        project ID when empty) and ``"number"`` (last segment of the resource
        name). Returns an empty list, after printing the error, if the lookup
        fails for any reason.
    """
    try:
        client = resourcemanager_v3.ProjectsClient()
        # search_projects returns every project the caller can access.
        # list_projects only lists the direct children of a parent folder or
        # organization, so calling it without a parent fails.
        request = resourcemanager_v3.SearchProjectsRequest()
        active_state = resourcemanager_v3.Project.State.ACTIVE

        return [
            {
                "id": project.project_id,
                "name": project.display_name or project.project_id,
                "number": project.name.split('/')[-1],
            }
            for project in client.search_projects(request=request)
            if project.state == active_state
        ]
    except Exception as e:
        print(f"Error listing projects: {e}")
        return []

def check_billing_enabled(project_id):
    """Report whether billing is enabled for a project.

    This is a placeholder: no billing lookup is performed and the function
    always returns True. A real check would query the Cloud Billing API.

    Args:
        project_id: Project ID to check (currently unused).

    Returns:
        bool: Always True.
    """
    # The previous try/except wrapped a bare ``return True`` that cannot raise,
    # so the "billing disabled" branch was unreachable.
    return True

def list_enabled_apis(project_id):
    """Return the names of the services currently enabled on ``project_id``.

    Each returned entry is the last path segment of the service resource
    name. Any failure is printed and reported as an empty list.
    """
    try:
        usage_client = service_usage_v1.ServiceUsageClient()
        enabled_request = service_usage_v1.ListServicesRequest(
            parent=f"projects/{project_id}",
            filter="state:ENABLED",
        )
        # Keep only the trailing segment of each resource name.
        return [
            svc.name.split('/')[-1]
            for svc in usage_client.list_services(request=enabled_request)
        ]
    except Exception as e:
        print(f"Error listing APIs: {e}")
        return []

def enable_api(project_id, api_name, timeout=300):
    """Enable a specific API for a project and wait for the change to finish.

    Args:
        project_id: ID of the project on which to enable the service.
        api_name: Service name, e.g. ``"bigquery.googleapis.com"``.
        timeout: Maximum number of seconds to wait for the enable operation.

    Returns:
        bool: True when the operation completed successfully, False if the
        request or the operation failed or timed out (the error is printed).
    """
    try:
        client = service_usage_v1.ServiceUsageClient()
        request = service_usage_v1.EnableServiceRequest(
            name=f"projects/{project_id}/services/{api_name}"
        )

        operation = client.enable_service(request=request)
        # Block on the long-running operation instead of sleeping a fixed time:
        # result() raises if the operation fails or the timeout elapses, so a
        # failed enable is no longer reported as success.
        operation.result(timeout=timeout)
        return True
    except Exception as e:
        print(f"Error enabling API {api_name}: {e}")
        return False

def check_apis_status(project_id):
    """Map each entry of REQUIRED_APIS to whether it is enabled on ``project_id``."""
    currently_enabled = list_enabled_apis(project_id)
    return {
        required_api: (required_api in currently_enabled)
        for required_api in REQUIRED_APIS
    }

# --- Core Functions ---
def init_clients(project_id, location):
    """Create the Gemini (Vertex AI) and BigQuery clients for a project.

    Args:
        project_id: Google Cloud project ID used by both clients.
        location: Vertex AI location passed to the Gemini client.

    Returns:
        tuple: ``(genai_client, bq_client)`` on success, or ``(None, None)``
        if either client could not be created (the error is printed).
    """
    try:
        new_genai_client = genai.Client(vertexai=True, project=project_id, location=location)
        new_bq_client = bigquery.Client(project=project_id)
        return new_genai_client, new_bq_client
    except Exception as e:
        # Print the cause, as the other helpers do, instead of discarding it.
        print(f"Error initializing clients: {e}")
        return None, None

def load_journal_data():
    """Load the journal impact table as a ``{title: SJR}`` dictionary.

    Reads the semicolon-separated CSV at JOURNAL_IMPACT_CSV_URL and pairs its
    ``Title`` and ``SJR`` columns.

    Returns:
        dict: Journal title mapped to its SJR value, or an empty dict if the
        file cannot be fetched or parsed (the error is printed).
    """
    try:
        df = pd.read_csv(JOURNAL_IMPACT_CSV_URL, sep=';')
        return dict(zip(df['Title'], df['SJR']))
    except Exception as e:
        # Best effort: the app still works without journal data, but say why
        # the table is empty rather than failing silently.
        print(f"Error loading journal data: {e}")
        return {}

def extract_medical_info(case_text, client):
    """Ask Gemini for the primary disease and the actionable events in a case.

    Two separate prompts are sent, one per field, each followed by the case
    notes.

    Args:
        case_text: Free-text patient case notes.
        client: Object exposing ``models.generate_content`` (the GenAI client).

    Returns:
        dict: ``{"disease": str, "events": str}`` with the stripped model
        output for each prompt; ``events`` is a comma-separated list. A field
        is an empty string when the model returned no text.
    """
    prompts = {
        "disease": "Extract the primary disease diagnosis. Return ONLY the name.",
        "events": "Extract all actionable medical events (mutations, biomarkers, etc.). Return a comma-separated list."
    }
    results = {}
    for key, prompt in prompts.items():
        full_prompt = f"{prompt}\n\nCase notes:\n{case_text}"
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[full_prompt],
            config=GenerateContentConfig(temperature=0),
        )
        # response.text may be None when no text is returned; treat that as
        # empty instead of raising AttributeError on .strip().
        results[key] = (response.text or "").strip()
    return results

def search_pubmed_articles(disease, events, bq_client, embedding_model, pubmed_table, top_k):
    """Run a BigQuery vector search for articles similar to the case.

    The search text is the disease followed by the events joined with spaces.
    It is embedded with ``embedding_model`` and matched against the
    ``ml_generate_embedding_result`` column of ``pubmed_table``.

    Args:
        disease: Disease name.
        events: Iterable of event strings.
        bq_client: BigQuery client used to run the query.
        embedding_model: Fully qualified BigQuery model name.
        pubmed_table: Fully qualified table to search.
        top_k: Number of nearest neighbours to return.

    Returns:
        pandas.DataFrame: Columns ``PMID``, ``content``, ``abstract`` and
        ``distance``.
    """
    query_text = f"{disease} {' '.join(events)}"
    # The search text comes from model output, so pass it as a query parameter
    # rather than splicing it into a SQL string literal, where a quote would
    # break or alter the statement. Identifiers cannot be parameterized and
    # top_k is coerced to int before interpolation.
    sql = f"""
        SELECT base.PMID, base.content, base.text AS abstract, distance
        FROM VECTOR_SEARCH(
            TABLE `{pubmed_table}`,
            'ml_generate_embedding_result',
            (
                SELECT ml_generate_embedding_result
                FROM ML.GENERATE_EMBEDDING(
                    MODEL `{embedding_model}`,
                    (SELECT @query_text AS content)
                )
            ),
            top_k => {int(top_k)}
        )
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("query_text", "STRING", query_text)
        ]
    )
    return bq_client.query(sql, job_config=job_config).to_dataframe()

def analyze_article_batch(df, disease, events, client, journal_dict):
    """Ask Gemini to extract structured metadata for a batch of articles.

    Args:
        df: DataFrame of articles with a ``PMID`` column and a ``content``
            (or ``abstract``) column.
        disease: Disease name used as relevance context.
        events: List of event strings used as relevance context.
        client: Object exposing ``models.generate_content`` (the GenAI client).
        journal_dict: ``{journal title: SJR}`` mapping included in the prompt.

    Returns:
        list: The parsed JSON array returned by the model, or an empty list
        when the response has no text, is not valid JSON, or is not an array.
    """
    # NOTE(review): every entry of journal_dict is placed in the prompt, which
    # can make the request very large - consider restricting it to relevant journals.
    journal_context = "\n".join(f"- {title}: {sjr}" for title, sjr in journal_dict.items())
    prompt = f"""Analyze articles for relevance to Disease: {disease} and Events: {', '.join(events)}. Use this data: {journal_context}. For each, extract: title, journal_title, journal_sjr, year, disease_match (bool), pediatric_focus (bool), treatment_shown (bool), paper_type, key_findings, clinical_trial (bool), novel_findings (bool). Return JSON array."""

    # Collect the per-article sections and join once instead of growing a
    # string inside the loop.
    article_sections = []
    for _, row in df.iterrows():
        content = row.get('content', row.get('abstract', ''))
        article_sections.append(f"\n---\nPMID: {row['PMID']}\nContent: {content}\n")
    articles_text = "".join(article_sections)

    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[prompt + articles_text],
        config=GenerateContentConfig(temperature=0, response_mime_type="application/json"),
    )
    try:
        parsed = json.loads(response.text)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers response.text being None (no text returned).
        return []
    # Callers iterate over the result as a list of per-article dicts.
    return parsed if isinstance(parsed, list) else []

def _flag_is_set(value):
    """Interpret an article metadata flag as a boolean, treating missing values as False."""
    if value is None:
        return False
    # NaN appears when an article has no analysis entry; bool(nan) is True,
    # so it must be rejected explicitly.
    if isinstance(value, float) and math.isnan(value):
        return False
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "no", "0", "none", "null", "nan")
    return bool(value)


def calculate_article_score(metadata, config):
    """Compute a relevance score for one article from its boolean flags.

    Args:
        metadata: Mapping-like object supporting ``.get`` (a dict or a pandas
            row) that may hold ``disease_match``, ``treatment_shown`` and
            ``clinical_trial`` flags.
        config: Mapping of point values with optional keys ``disease_match``
            (default 50), ``treatment_efficacy`` (default 50) and
            ``clinical_trial`` (default 40).

    Returns:
        The sum of the points for every flag that is set, rounded to 2 decimals.
        Missing, None, NaN and "false"-like string flags contribute nothing.
    """
    score = 0
    # Simplified scoring logic for brevity
    if _flag_is_set(metadata.get('disease_match')):
        score += config.get('disease_match', 50)
    if _flag_is_set(metadata.get('treatment_shown')):
        score += config.get('treatment_efficacy', 50)
    if _flag_is_set(metadata.get('clinical_trial')):
        score += config.get('clinical_trial', 40)
    return round(score, 2)

def setup_bigquery(project, dataset, location, client):
    """Ensure the user's dataset and remote embedding model exist in BigQuery.

    Looks up ``project.dataset`` and creates it if the lookup fails, then runs
    ``CREATE MODEL IF NOT EXISTS`` for ``project.dataset.textembed`` using the
    default connection and the ``text-embedding-005`` endpoint.

    Args:
        project: Google Cloud project ID.
        dataset: BigQuery dataset name inside ``project``.
        location: Accepted for interface compatibility; currently unused, so
            the dataset is created with the client's default location.
        client: BigQuery client used for all calls.

    Returns:
        str: A success message naming the dataset.

    Raises:
        Exception: Whatever ``create_dataset`` or the model query raises is
        propagated to the caller.
    """
    dataset_ref = f"{project}.{dataset}"
    try:
        client.get_dataset(dataset_ref)
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit are not mistaken for "dataset missing".
        client.create_dataset(bigquery.Dataset(dataset_ref), exists_ok=True)
    model_query = f"CREATE MODEL IF NOT EXISTS `{project}.{dataset}.textembed` REMOTE WITH CONNECTION DEFAULT OPTIONS(endpoint='text-embedding-005');"
    client.query(model_query).result()
    return f"✅ BigQuery setup complete for {project}.{dataset}"

# --- Gradio App Logic ---
def refresh_projects():
    """Reload the project list and build the dropdown update plus a status line.

    Returns a ``(dropdown update, status message)`` pair; the first project
    found is pre-selected.
    """
    found = list_projects()
    if found:
        labels = []
        for entry in found:
            labels.append(f"{entry['name']} ({entry['id']})")
        # labels is non-empty here, so the first label is always available.
        dropdown_update = gr.update(choices=labels, value=labels[0])
        return dropdown_update, f"✅ Found {len(found)} projects"

    return gr.update(choices=[], value=None), "❌ No projects found. Please ensure you have access to at least one GCP project."

def check_project_setup(project_selection, location, dataset_name):
    """Check whether the selected project is ready for setup.

    Args:
        project_selection: Dropdown label of the form ``"Name (project-id)"``.
        location: Selected location (not used by the checks).
        dataset_name: Dataset name (not used by the checks).

    Returns:
        tuple: ``(status text, update for the "enable APIs" container,
        update for the "Complete Setup" button)``. The container is shown only
        when required APIs are missing; the button is enabled only when every
        check passes.
    """
    if not project_selection:
        return "❌ Please select a project first.", gr.update(visible=False), gr.update(interactive=False)

    # Extract project ID from selection: the text inside the last parentheses.
    project_id = project_selection.split('(')[-1].rstrip(')')

    status_messages = []
    all_good = True

    # Check billing
    if check_billing_enabled(project_id):
        status_messages.append("✅ Billing is enabled")
    else:
        status_messages.append("❌ Billing is not enabled")
        all_good = False

    # Check APIs
    api_status = check_apis_status(project_id)
    missing_apis = [api for api, enabled in api_status.items() if not enabled]

    if not missing_apis:
        status_messages.append("✅ All required APIs are enabled")
    else:
        status_messages.append(f"❌ Missing APIs: {', '.join(missing_apis)}")
        all_good = False

    status_text = "\n".join(status_messages)

    if all_good:
        status_text += "\n\n✅ Project is ready! Click 'Complete Setup' to proceed."
        return status_text, gr.update(visible=False), gr.update(interactive=True)

    # The second output is a layout container, which has no value to set, so
    # only toggle its visibility; the missing APIs are listed in status_text.
    return status_text, gr.update(visible=True), gr.update(interactive=False)

def enable_missing_apis(project_selection, missing_apis, progress=gr.Progress()):
    """Enable the required APIs that are missing on the selected project.

    Args:
        project_selection: Dropdown label of the form ``"Name (project-id)"``.
        missing_apis: List of service names to enable. When empty, the list is
            recomputed from the project's current API status.
        progress: Gradio progress tracker.

    Returns:
        str: One result line per API followed by a prompt to re-check the
        project, or an error message when there is nothing to enable.
    """
    if not project_selection:
        return "❌ No APIs to enable."

    # Extract project ID from selection: the text inside the last parentheses.
    project_id = project_selection.split('(')[-1].rstrip(')')

    # The list handed in by the UI may be empty even though APIs are missing,
    # so fall back to a fresh status check instead of giving up immediately.
    if not missing_apis:
        missing_apis = [
            api for api, enabled in check_apis_status(project_id).items() if not enabled
        ]
    if not missing_apis:
        return "❌ No APIs to enable."

    results = []
    for i, api in enumerate(missing_apis):
        progress(i / len(missing_apis), desc=f"Enabling {api}...")
        if enable_api(project_id, api):
            results.append(f"✅ Enabled {api}")
        else:
            results.append(f"❌ Failed to enable {api}")

    return "\n".join(results) + "\n\n🔄 Please click 'Check Project Status' again."

def complete_setup(project_selection, location, dataset_name):
    """Initialize clients, load journal data and prepare BigQuery.

    Stores the chosen project, location and dataset in the module-level
    globals, creates the Gemini and BigQuery clients, loads the journal
    table, and ensures the dataset and embedding model exist.

    Args:
        project_selection: Dropdown label of the form ``"Name (project-id)"``.
        location: Vertex AI location for the Gemini client.
        dataset_name: BigQuery dataset that will hold the embedding model.

    Returns:
        tuple: ``(status text, update for the analyze button)``. The button is
        enabled only when every step succeeded.
    """
    global genai_client, bq_client, journal_impact_dict, PROJECT_ID, LOCATION, USER_DATASET

    if not project_selection:
        return "❌ Please select a project first.", gr.update(interactive=False)

    # Extract project ID from selection: the text inside the last parentheses.
    PROJECT_ID = project_selection.split('(')[-1].rstrip(')')
    LOCATION = location
    USER_DATASET = dataset_name

    genai_client, bq_client = init_clients(PROJECT_ID, LOCATION)
    if not genai_client or not bq_client:
        return "❌ Failed to initialize clients. Please check your permissions.", gr.update(interactive=False)

    journal_impact_dict = load_journal_data()
    try:
        setup_message = setup_bigquery(PROJECT_ID, USER_DATASET, LOCATION, bq_client)
    except Exception as e:
        # Report the failure in the status panel and keep the analyze button
        # disabled rather than letting the exception escape the handler.
        return f"❌ BigQuery setup failed: {e}", gr.update(interactive=False)

    return f"✅ Setup complete!\n\n📍 Project: {PROJECT_ID}\n📍 Location: {LOCATION}\n📍 Dataset: {USER_DATASET}\n\nLoaded {len(journal_impact_dict)} journals.\n{setup_message}", gr.update(interactive=True)

def run_analysis(case_text, num_articles, progress=gr.Progress()):
    """Run the full pipeline: extract, search, analyze, score and rank.

    Args:
        case_text: Free-text patient case notes.
        num_articles: Number of articles to retrieve from the vector search.
        progress: Gradio progress tracker.

    Returns:
        tuple: ``(results table, status message, results dict)``. The table
        holds ``score``, ``title``, ``journal_title`` and ``year`` for the ten
        highest-scoring articles. The dict holds all article records plus the
        disease, events and case text. When setup is incomplete or no article
        is found, the table is None and the dict is empty.
    """
    if not genai_client or not bq_client:
        return None, "❌ Please complete setup first.", {}
    progress(0.1, desc="Extracting medical info...")
    medical_info = extract_medical_info(case_text, genai_client)
    disease = medical_info.get('disease', '')
    # Drop blank entries so an empty extraction gives [] rather than [''].
    events = [e.strip() for e in medical_info.get('events', '').split(',') if e.strip()]

    progress(0.3, desc="Searching PubMed...")
    embedding_model_path = f"{PROJECT_ID}.{USER_DATASET}.textembed"
    articles_df = search_pubmed_articles(disease, events, bq_client, embedding_model_path, PUBMED_TABLE, int(num_articles))
    if articles_df.empty:
        return None, f"⚠️ No PubMed articles found for '{disease}'.", {}
    # Guarantee a 0..n-1 index so analyses can be matched to rows by position.
    articles_df = articles_df.reset_index(drop=True)

    progress(0.6, desc="Analyzing articles...")
    analyses = analyze_article_batch(articles_df, disease, events, genai_client, journal_impact_dict)
    if not isinstance(analyses, list):
        analyses = []

    # Merge and score. Analyses are matched to articles by position; entries
    # beyond the number of articles, or that are not dicts, are ignored.
    # NOTE(review): positional matching assumes the model returns one entry per
    # article in input order - confirm, or match on PMID instead.
    records = [a if isinstance(a, dict) else {} for a in analyses[:len(articles_df)]]
    if records:
        # Building a frame handles list/dict cell values, which per-cell
        # .loc assignment does not.
        metadata_df = pd.DataFrame(records, index=articles_df.index[:len(records)])
        for column in metadata_df.columns:
            if column in articles_df.columns:
                articles_df[column] = metadata_df[column].combine_first(articles_df[column])
            else:
                articles_df[column] = metadata_df[column]

    # The results table needs these columns even when the analysis returned nothing.
    for column in ('title', 'journal_title', 'year'):
        if column not in articles_df.columns:
            articles_df[column] = None

    scoring_config = SCORING_PRESETS["Clinical Focus"] # Simplified for this version
    articles_df['score'] = articles_df.apply(lambda row: calculate_article_score(row, scoring_config), axis=1)
    articles_df = articles_df.sort_values('score', ascending=False).reset_index()

    progress(0.9, desc="Generating results...")
    results_table = articles_df[['score', 'title', 'journal_title', 'year']].head(10)
    results = {'articles': articles_df.to_dict('records'), 'disease': disease, 'events': events, 'case_text': case_text}
    return results_table, f"✅ Analysis complete for '{disease}'.", results

# Build the three-tab Gradio app (Setup, Analyze Case, Results) and wire its events.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 PubMed Medical Literature Analysis")
    app_state = gr.State({})
    missing_apis_state = gr.State([])

    with gr.Tabs():
        with gr.TabItem("1. Setup"):
            with gr.Row():
                project_dropdown = gr.Dropdown(label="Select Google Cloud Project", interactive=True)
                refresh_btn = gr.Button("🔄 Refresh Projects")

            location_dropdown = gr.Dropdown(label="Select Location", choices=VERTEX_AI_LOCATIONS, value="us-central1")
            dataset_input = gr.Textbox(label="BigQuery Dataset Name", value="pubmed")

            check_btn = gr.Button("Check Project Status", variant="secondary")
            status_output = gr.Markdown()

            # gr.Group replaces gr.Box, which is no longer available in
            # Gradio 4+; the install cell does not pin a Gradio version.
            with gr.Group(visible=False) as api_box:
                gr.Markdown("The following APIs need to be enabled. Click the button below to enable them.")
                enable_apis_btn = gr.Button("Enable Missing APIs", variant="primary")

            # Stays disabled until the project check succeeds.
            setup_btn = gr.Button("Complete Setup", variant="primary", interactive=False)
            setup_status = gr.Markdown()

        with gr.TabItem("2. Analyze Case"):
            case_input = gr.Textbox(label="Patient Case Notes", value=SAMPLE_CASE, lines=10)
            num_articles_slider = gr.Slider(5, 50, 10, step=1, label="Number of Articles to Analyze")
            # Stays disabled until setup completes.
            analyze_btn = gr.Button("Run Full Analysis", variant="primary", interactive=False)
            analysis_status = gr.Markdown()

        with gr.TabItem("3. Results"):
            results_df = gr.DataFrame(label="Top 10 Ranked Articles")

    # Setup Tab Interactions
    refresh_btn.click(refresh_projects, outputs=[project_dropdown, status_output])
    check_btn.click(check_project_setup, inputs=[project_dropdown, location_dropdown, dataset_input], outputs=[status_output, api_box, setup_btn])
    enable_apis_btn.click(enable_missing_apis, inputs=[project_dropdown, missing_apis_state], outputs=[status_output])
    setup_btn.click(complete_setup, inputs=[project_dropdown, location_dropdown, dataset_input], outputs=[setup_status, analyze_btn])

    # Analysis Tab Interactions
    analyze_btn.click(run_analysis, inputs=[case_input, num_articles_slider], outputs=[results_df, analysis_status, app_state])

    # Initial load: populate the project dropdown as soon as the page opens.
    demo.load(refresh_projects, outputs=[project_dropdown, status_output])

# NOTE(review): share=True publishes a public URL for an app that runs with the
# signed-in user's Google Cloud credentials - confirm this exposure is intended.
demo.launch(share=True, debug=True)