In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PubMed Medical Literature Analysis

<!-- [PLACEHOLDER: Update these links when notebook is finalized] -->

<table style="float: left; margin-right: 20px;">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FWandLZhang%2Fpubmed-rag%2Fmain%2FPubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/WandLZhang/pubmed-rag/main/PubMed_RAG_Gradio.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<!-- [PLACEHOLDER: Update share links when notebook is finalized] -->
<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Gradio.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>


| Authors |
| --- |
| [Willis Zhang](https://github.com/WandLZhang) |
| [Stone Jiang](https://github.com/siduojiang) |

## Overview

**Blog Post: Medical Literature Analysis with PubMed, BigQuery, Gemini**

<a href="[blog-post-url-placeholder]" target="_blank">
  <img src="https://storage.googleapis.com/[placeholder-image-path]/medical-literature-blog-header.jpg" alt="Medical Literature Analysis with PubMed and Gemini" width="500">
</a>

This notebook demonstrates how to analyze medical cases using PubMed literature with BigQuery vector search and Gemini. It brings the core user experience of the [Capricorn Medical Research Application](https://capricorn-medical-research.web.app/) into an interactive Colab notebook.

## 🚀 Quick Start

1. Click **Runtime → Run all** (or press Ctrl/Cmd + F9).
2. Sign in with your Google account when the authentication pop-up appears.
3. Continue in the embedded Gradio app, or open the public link printed in the cell output.

This notebook analyzes medical cases using PubMed literature with BigQuery vector search and Gemini AI.

In [None]:
# @title 1️⃣ Install Dependencies { display-mode: "form" }
# Installs the Gradio UI framework, the Google Gen AI SDK, the Google Cloud
# client libraries used below (BigQuery, Resource Manager, Service Usage,
# Billing), plus pandas and plotly. `-q` keeps pip output quiet.
# NOTE(review): package versions are unpinned; consider pinning them and using
# `%pip` so the install targets the running kernel's environment.
!pip install gradio google-genai google-cloud-bigquery google-cloud-resource-manager google-cloud-service-usage pandas plotly google-cloud-billing -q
print("✅ Dependencies installed")

In [None]:
# @title 2️⃣ Authenticate with Google Cloud { display-mode: "form" }
import sys

# Colab requires an interactive user login; outside Colab the client
# libraries are left to pick up the environment's default credentials.
if "google.colab" not in sys.modules:
    print("ℹ️ Running locally - using default credentials")
else:
    from google.colab import auth
    auth.authenticate_user()
    print("✅ Authenticated with Google Cloud")

In [None]:
# @title 3️⃣ Launch PubMed Analysis App { display-mode: "form" }

import gradio as gr
import pandas as pd
import json
import math
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
from google import genai
from google.cloud import bigquery
from google.cloud import resourcemanager_v3
from google.cloud import service_usage_v1
from google.cloud import billing_v1
from google.genai.types import GenerateContentConfig
import time
import os
import webbrowser

# --- Constants ---
# BigQuery table (in a separate project) with PubMed articles and precomputed
# embeddings in `ml_generate_embedding_result`, queried via VECTOR_SEARCH.
PUBMED_DATASET = "wz-data-catalog-demo.pubmed"
PUBMED_TABLE = f"{PUBMED_DATASET}.pmid_embed_nonzero_metadata"
# Gemini 1.5 models have been retired on Vertex AI; use a current Flash model
# that is served from the `global` location configured below.
MODEL_ID = "gemini-2.5-flash"
# SCImago Journal Rank table (';'-separated, with `Title` and `SJR` columns).
JOURNAL_IMPACT_CSV_URL = "https://raw.githubusercontent.com/WandLZhang/scimagojr_2024/main/scimagojr_2024.csv"
# Services that must be enabled on the selected/created project.
REQUIRED_APIS = ["aiplatform.googleapis.com", "bigquery.googleapis.com", "cloudresourcemanager.googleapis.com"]
CREATE_BILLING_ACCOUNT_URL = "https://console.cloud.google.com/billing/create?inv=1&invt=Ab4E_Q"
# Pseudo-entry appended to billing dropdowns; selecting it shows setup instructions.
CREATE_BILLING_ACCOUNT_OPTION = "→ Create New Billing Account"
# Example case pre-filled in the "Case" tab.
SAMPLE_CASE = """A 4-year-old male presents with a 3-week history of progressive fatigue, pallor, and easy bruising.
Physical examination reveals hepatosplenomegaly and scattered petechiae.

Laboratory findings:
- WBC: 45,000/μL with 80% blasts
- Hemoglobin: 7.2 g/dL
- Platelets: 32,000/μL

Flow cytometry: CD33+, CD13+, CD117+, CD34+, HLA-DR+, CD19-, CD3-

Cytogenetics: 46,XY,t(9;11)(p21.3;q23.3)
Molecular: KMT2A-MLLT3 fusion detected, FLT3-ITD positive, NRAS G12D mutation

Diagnosis: KMT2A-rearranged acute myeloid leukemia (AML)"""

# --- Global Variables ---
# Filled in by the setup handlers once a project has been configured.
genai_client, bq_client = None, None
journal_impact_dict = {}
PROJECT_ID = ""
LOCATION = "global"
# Dataset created in the user's project to hold the remote embedding model.
USER_DATASET = "pubmed"

# Scoring presets: points added per article flag (see calculate_article_score).
SCORING_PRESETS = {
    "Clinical Focus": {
        "disease_match": 50,
        "treatment_efficacy": 50,
        "clinical_trial": 40
    }
}

# --- Helper Functions for Enhanced Setup ---
def list_projects():
    """Return the caller's ACTIVE Google Cloud projects, ordered by project ID.

    Each entry is a dict with keys ``id`` (project ID), ``name`` (display
    name) and ``number`` (the last path segment of the ``projects/<number>``
    resource name). Any error is printed and an empty list is returned.
    """
    try:
        projects_client = resourcemanager_v3.ProjectsClient()
        search_request = resourcemanager_v3.SearchProjectsRequest(query="")
        active_state = resourcemanager_v3.Project.State.ACTIVE
        found = [
            {
                "id": proj.project_id,
                "name": proj.display_name,
                "number": proj.name.split('/')[-1],
            }
            for proj in projects_client.search_projects(request=search_request)
            if proj.state == active_state
        ]
        found.sort(key=lambda entry: entry['id'])
        return found
    except Exception as e:
        print(f"Error listing projects: {e}")
        return []

def check_billing_enabled(project_id):
    """Return whether billing is enabled on ``project_id``.

    Any failure while asking the Cloud Billing API (permissions, API not
    enabled, ...) is printed and reported as ``False``.
    """
    resource_name = f"projects/{project_id}"
    try:
        info = billing_v1.CloudBillingClient().get_project_billing_info(name=resource_name)
        return info.billing_enabled
    except Exception as e:
        print(f"Could not check billing for project {project_id}: {e}")
        return False

def list_enabled_apis(project_id):
    """Return the names (e.g. ``bigquery.googleapis.com``) of services enabled on ``project_id``.

    Errors are printed and reported as an empty list.
    """
    try:
        usage_client = service_usage_v1.ServiceUsageClient()
        enabled_request = service_usage_v1.ListServicesRequest(
            parent=f"projects/{project_id}",
            filter="state:ENABLED",
        )
        names = []
        for svc in usage_client.list_services(request=enabled_request):
            # Resource names look like ".../services/<api>"; keep only <api>.
            names.append(svc.name.rsplit('/', 1)[-1])
        return names
    except Exception as e:
        print(f"Error listing APIs: {e}")
        return []

def enable_apis(project_id, apis_to_enable, progress=gr.Progress()):
    """Enable each service in ``apis_to_enable`` on ``project_id``, one at a time.

    Args:
        project_id: Target project ID.
        apis_to_enable: Service names such as ``"bigquery.googleapis.com"``.
        progress: Gradio progress tracker, advanced once per API.

    Returns:
        True once every API has been enabled (immediately if the list is empty).

    Raises:
        RuntimeError: If enabling an API fails or does not finish within 300
            seconds. The underlying error is chained as ``__cause__``.
    """
    if not apis_to_enable:
        # Nothing to do; avoid creating a client (and needing credentials).
        return True
    client = service_usage_v1.ServiceUsageClient()
    total_apis = len(apis_to_enable)
    for i, api_name in enumerate(apis_to_enable):
        progress((i + 1) / total_apis, desc=f"Enabling {api_name}...")
        try:
            request = service_usage_v1.EnableServiceRequest(name=f"projects/{project_id}/services/{api_name}")
            operation = client.enable_service(request=request)
            operation.result(timeout=300)  # block until the long-running operation completes
        except Exception as e:
            raise RuntimeError(f"Error enabling API {api_name}: {e}") from e
    return True

def list_billing_accounts():
    """Return dropdown labels for the caller's open billing accounts.

    Labels have the form ``"<display name> (<ACCOUNT_ID>)"``. The
    ``CREATE_BILLING_ACCOUNT_OPTION`` pseudo-entry is always appended last.
    If listing fails, it is the only entry returned.
    """
    try:
        billing_client = billing_v1.CloudBillingClient()
        labels = []
        for account in billing_client.list_billing_accounts():
            if not account.open:
                continue
            account_id = account.name.split('/')[-1]
            labels.append(f"{account.display_name} ({account_id})")
        labels.append(CREATE_BILLING_ACCOUNT_OPTION)
        return labels
    except Exception as e:
        print(f"Error listing billing accounts: {e}")
        return [CREATE_BILLING_ACCOUNT_OPTION]

def create_new_project(project_id, billing_account_name, progress=gr.Progress()):
    """Create a GCP project, link billing, enable REQUIRED_APIS and run setup.

    Args:
        project_id: ID for the new project (also used as its display name).
        billing_account_name: Billing dropdown label of the form
            ``"<display name> (<ACCOUNT_ID>)"``.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, dropdown_label)``. On success the label is
        ``"<project_id> (<project_id>)"``, matching the ``"<name> (<id>)"``
        labels used by the project dropdown since the display name equals the
        ID. On failure the label is None.

    Side effects:
        On success, sets the module globals ``genai_client``, ``bq_client``,
        ``journal_impact_dict`` and ``PROJECT_ID`` so that ``run_analysis``
        targets the new project.
    """
    global genai_client, bq_client, journal_impact_dict, PROJECT_ID

    project_id = (project_id or "").strip()
    if not project_id:
        return "❌ Error creating project: please enter a project ID.", None
    if not billing_account_name or billing_account_name == CREATE_BILLING_ACCOUNT_OPTION:
        return "❌ Error creating project: please select an existing billing account.", None

    try:
        progress(0.1, desc="Creating project...")
        project_client = resourcemanager_v3.ProjectsClient()
        project = {'project_id': project_id, 'display_name': project_id}
        operation = project_client.create_project(project=project)
        created_project = operation.result(timeout=300)

        progress(0.4, desc="Linking billing account...")
        billing_client = billing_v1.CloudBillingClient()
        # Label is "<display name> (<ACCOUNT_ID>)": take the last token and drop the parentheses.
        billing_account_id = billing_account_name.split(' ')[-1].strip('()')
        project_billing_info = {'billing_account_name': f"billingAccounts/{billing_account_id}"}
        billing_client.update_project_billing_info(
            name=f"projects/{created_project.project_id}",
            project_billing_info=project_billing_info
        )

        progress(0.6, desc="Enabling APIs...")
        enable_apis(project_id, REQUIRED_APIS, progress)

        # Add a delay to ensure project propagation and IAM permissions
        progress(0.7, desc="Waiting for project propagation...")
        time.sleep(10)  # 10-second delay for IAM permissions to propagate

        # Use the shared setup logic
        genai_client, bq_client, journal_impact_dict = setup_project(project_id, LOCATION, USER_DATASET, progress)
        # run_analysis builds the embedding-model path from PROJECT_ID, so it
        # must point at the project that was just set up.
        PROJECT_ID = project_id

        return f"✅ Project '{project_id}' created and set up.", f"{project_id} ({project_id})"
    except Exception as e:
        return f"❌ Error creating project: {e}", None

def link_billing_to_project(project_id, billing_account_name):
    """Attach an existing billing account to ``project_id``.

    ``billing_account_name`` is a dropdown label ``"<display name> (<ACCOUNT_ID>)"``.
    The account ID is its last space-separated token with the parentheses
    stripped.

    Returns:
        ``(success, message)``. Failures are reported in the message, not raised.
    """
    try:
        billing_client = billing_v1.CloudBillingClient()
        account_id = billing_account_name.split(' ')[-1].strip('()')
        billing_client.update_project_billing_info(
            name=f"projects/{project_id}",
            project_billing_info={'billing_account_name': f"billingAccounts/{account_id}"},
        )
    except Exception as e:
        return False, f"❌ Error linking billing account: {e}"
    return True, "✅ Billing account linked successfully!"

# --- Core Functions ---
def setup_project(project_id, location, dataset, progress=gr.Progress()):
    """Run the setup shared by new and existing projects.

    Points ``GOOGLE_CLOUD_PROJECT`` at ``project_id`` and creates the Gen AI
    and BigQuery clients. Ensures the BigQuery dataset and remote embedding
    model exist, then loads the journal SJR lookup.

    Returns:
        ``(genai_client, bq_client, journal_impact_dict)``. These are locals;
        callers assign them to the module globals.

    Raises:
        ConnectionError: If the clients could not be initialized.
        Exception: Whatever ``setup_bigquery`` raises if the dataset or model
            cannot be created.
    """
    os.environ['GOOGLE_CLOUD_PROJECT'] = project_id

    progress(0.7, desc="Initializing clients...")
    genai_client, bq_client = init_clients(project_id, location)
    if not genai_client or not bq_client:
        raise ConnectionError("Failed to initialize Google Cloud clients.")

    # Setup BigQuery dataset and model
    setup_bigquery(project_id, dataset, bq_client, progress)

    progress(0.9, desc="Loading journal data...")
    journal_impact_dict = load_journal_data()

    return genai_client, bq_client, journal_impact_dict

def init_clients(project_id, location):
    """Create the Gen AI (Vertex AI) and BigQuery clients for ``project_id``.

    APIs and IAM on a freshly created project may not be usable right away.
    Client creation plus a trivial ``SELECT 1`` probe is therefore tried up
    to 3 times, sleeping 5 s after the first failure and 10 s after the second.

    Returns:
        ``(genai_client, bq_client)`` on success, or ``(None, None)`` once all
        attempts have failed.
    """
    total_attempts = 3
    attempt_delays = [5, 10, 15]  # seconds to sleep after attempt 1, 2, ... (last entry unused)

    # Ensure the project ID is set in the environment
    os.environ['GOOGLE_CLOUD_PROJECT'] = project_id

    for attempt_no in range(1, total_attempts + 1):
        print(f"Attempting to initialize clients for project {project_id} (attempt {attempt_no}/{total_attempts})...")
        try:
            gen_client = genai.Client(vertexai=True, project=project_id, location=location)
            bigquery_client = bigquery.Client(project=project_id)
            # Cheap round-trip proving the BigQuery API is usable for this project.
            bigquery_client.query("SELECT 1").result()
        except Exception as e:
            print(f"Attempt {attempt_no} failed: {e}")
            if attempt_no == total_attempts:
                print(f"All {total_attempts} attempts failed. Error initializing clients for project {project_id}: {e}")
                return None, None
            wait_s = attempt_delays[attempt_no - 1]
            print(f"Waiting {wait_s} seconds before retry...")
            time.sleep(wait_s)
        else:
            print(f"Successfully initialized clients for project {project_id}")
            return gen_client, bigquery_client

def load_journal_data(csv_url=None):
    """Load the SCImago Journal Rank table as a ``{journal title: SJR}`` dict.

    Args:
        csv_url: Path or URL of a ``;``-separated CSV with ``Title`` and
            ``SJR`` columns. Defaults to ``JOURNAL_IMPACT_CSV_URL``.

    Returns:
        Mapping from journal title to its SJR value as parsed by pandas. If
        the file cannot be read or lacks those columns, an empty dict is
        returned. The journal data only enriches the Gemini prompt, so a
        failure is reported but never fatal.
    """
    source = csv_url if csv_url is not None else JOURNAL_IMPACT_CSV_URL
    try:
        df = pd.read_csv(source, sep=';')
        return dict(zip(df['Title'], df['SJR']))
    except Exception as e:
        # Was a bare `except:`; keep best-effort semantics but say why.
        print(f"⚠️ Could not load journal impact data from {source}: {e}")
        return {}

def extract_medical_info(case_text, client):
    """Use Gemini to pull the diagnosis and actionable events out of case notes.

    Args:
        case_text: Free-text patient case notes.
        client: ``google.genai`` client exposing ``models.generate_content``.

    Returns:
        ``{"disease": str, "events": str}``. ``events`` is the model's
        comma-separated list. A value is ``""`` when the model returned no
        text (``response.text`` is None, e.g. a blocked response).
    """
    prompts = {
        "disease": "Extract the primary disease diagnosis. Return ONLY the name.",
        "events": "Extract all actionable medical events (mutations, biomarkers, etc.). Return a comma-separated list."
    }
    results = {}
    for key, prompt in prompts.items():
        full_prompt = f"{prompt}\n\nCase notes:\n{case_text}"
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[full_prompt],
            config=GenerateContentConfig(temperature=0),  # deterministic extraction
        )
        results[key] = (response.text or "").strip()
    return results

def search_pubmed_articles(disease, events, bq_client, embedding_model, pubmed_table, top_k):
    """Vector-search PubMed for articles closest to the disease + events query.

    The query text is embedded with ``embedding_model`` via
    ``ML.GENERATE_EMBEDDING`` and matched against the table's
    ``ml_generate_embedding_result`` column with ``VECTOR_SEARCH``.

    Args:
        disease: Diagnosis string.
        events: Iterable of event strings, joined with spaces into the query.
        bq_client: ``bigquery.Client`` to run the query with.
        embedding_model: Fully qualified remote model ``project.dataset.model``.
        pubmed_table: Fully qualified table to search.
        top_k: Number of neighbours to return; coerced to ``int``.

    Returns:
        DataFrame with ``PMID``, ``content``, ``abstract`` and ``distance`` columns.
    """
    query_text = f"{disease} {' '.join(events)}"
    print(f"Using BigQuery project: {bq_client.project}")
    print(f"Using embedding model: {embedding_model}")

    # The query text is LLM output derived from user-typed notes, so bind it as
    # a query parameter: splicing it into the SQL broke (or altered) the
    # statement whenever it contained a quote or backslash. Table/model
    # identifiers cannot be parameterized; they come from app configuration.
    sql = f"""SELECT base.PMID, base.content, base.text as abstract, distance FROM VECTOR_SEARCH(TABLE `{pubmed_table}`, 'ml_generate_embedding_result', (SELECT ml_generate_embedding_result FROM ML.GENERATE_EMBEDDING(MODEL `{embedding_model}`, (SELECT @query_text AS content))), top_k => {int(top_k)})"""
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("query_text", "STRING", query_text)]
    )
    return bq_client.query(sql, job_config=job_config).to_dataframe()

def analyze_article_batch(df, disease, events, client, journal_dict):
    """Ask Gemini to extract structured metadata for every article in ``df``.

    Args:
        df: Search results with ``PMID`` and ``content`` (or ``abstract``) columns.
        disease: Diagnosis extracted from the case.
        events: List of actionable events extracted from the case.
        client: ``google.genai`` client.
        journal_dict: ``{journal title: SJR}`` lookup embedded in the prompt.

    Returns:
        A list with one dict per article, in the order the model returned
        them; the caller aligns them with ``df`` rows by position. Items that
        are not JSON objects are replaced by ``{}`` so positions are kept.
        An empty list is returned when the response has no text or is not
        valid JSON.
    """
    # NOTE(review): every journal in journal_dict goes into the prompt; the
    # full SCImago table presumably has thousands of rows, which makes each
    # call very large. Consider pre-filtering.
    journal_context = "\n".join([f"- {title}: {sjr}" for title, sjr in journal_dict.items()])
    prompt = f"""Analyze articles for relevance to Disease: {disease} and Events: {', '.join(events)}. Use this data: {journal_context}. For each, extract: title, journal_title, journal_sjr, year, disease_match (bool), pediatric_focus (bool), treatment_shown (bool), paper_type, key_findings, clinical_trial (bool), novel_findings (bool). Return JSON array."""
    articles_text = ""
    for _, row in df.iterrows():
        content = row.get('content', row.get('abstract', ''))
        articles_text += f"\n---\nPMID: {row['PMID']}\nContent: {content}\n"
    response = client.models.generate_content(model=MODEL_ID, contents=[prompt + articles_text], config=GenerateContentConfig(temperature=0, response_mime_type="application/json"))
    try:
        # response.text is None when the model returned no text part.
        parsed = json.loads(response.text or "")
    except json.JSONDecodeError:
        return []
    if isinstance(parsed, dict):
        # Tolerate a wrapper object such as {"articles": [...]} or a single bare article.
        wrapped = [
            value for value in parsed.values()
            if isinstance(value, list) and value and all(isinstance(item, dict) for item in value)
        ]
        parsed = wrapped[0] if wrapped else [parsed]
    if not isinstance(parsed, list):
        return []
    return [item if isinstance(item, dict) else {} for item in parsed]

def _is_true_flag(value):
    """Interpret an LLM-produced boolean field; missing/NaN/"false" count as False."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    if value is None or value is pd.NA:
        return False
    if isinstance(value, float) and math.isnan(value):
        # Rows the model did not analyze carry NaN, which is truthy in Python.
        return False
    return bool(value)


def calculate_article_score(metadata, config):
    """Score an article by summing the points for each flag it satisfies.

    Args:
        metadata: Mapping (dict or pandas row) with optional ``disease_match``,
            ``treatment_shown`` and ``clinical_trial`` flags.
        config: Points per flag. Keys are ``disease_match``,
            ``treatment_efficacy`` (applied to ``treatment_shown``) and
            ``clinical_trial``. Missing keys default to 50/50/40.

    Returns:
        The score rounded to 2 decimals.
    """
    score = 0
    if _is_true_flag(metadata.get('disease_match')):
        score += config.get('disease_match', 50)
    if _is_true_flag(metadata.get('treatment_shown')):
        score += config.get('treatment_efficacy', 50)
    if _is_true_flag(metadata.get('clinical_trial')):
        score += config.get('clinical_trial', 40)
    return round(score, 2)

def setup_bigquery(project, dataset, client, progress=gr.Progress()):
    """Ensure ``project.dataset`` and its remote ``textembed`` model exist.

    The dataset is created if it cannot be fetched. The model is a BigQuery
    remote model over the ``text-embedding-005`` endpoint, created with
    ``CREATE MODEL IF NOT EXISTS`` using the default connection. Model
    creation is attempted up to 3 times:
    - after an "internal error during execution" it waits 5 s, then 10 s;
    - after any other error it waits half as long (2 s, then 5 s).

    Returns:
        A ✅ status string on success.

    Raises:
        RuntimeError: If the model could not be created after all attempts.
            The last error is chained as ``__cause__``.
    """
    progress(0.8, desc="Setting up BigQuery dataset and model (may take a couple minutes if first time)...")

    dataset_ref = f"{project}.{dataset}"
    # Create dataset if it doesn't exist (exists_ok also tolerates a concurrent create).
    try:
        client.get_dataset(dataset_ref)
    except Exception:
        client.create_dataset(bigquery.Dataset(dataset_ref), exists_ok=True)

    model_query = f"CREATE MODEL IF NOT EXISTS `{project}.{dataset}.textembed` REMOTE WITH CONNECTION DEFAULT OPTIONS(endpoint='text-embedding-005');"

    max_retries = 3
    retry_delays = [5, 10, 15]

    for attempt in range(max_retries):
        try:
            print(f"Creating BigQuery embedding model (attempt {attempt + 1}/{max_retries})...")
            client.query(model_query).result()
            print(f"Successfully created BigQuery model for {project}.{dataset}")
            return f"✅ BigQuery setup complete for {project}.{dataset}"
        except Exception as e:
            error_msg = str(e)
            print(f"Attempt {attempt + 1} failed: {error_msg}")

            if attempt == max_retries - 1:
                print(f"All {max_retries} attempts failed.")
                raise RuntimeError(f"Failed to create BigQuery model after {max_retries} attempts. Last error: {error_msg}") from e

            # Internal execution errors right after project setup look timing-related,
            # so they get the full delay; other errors retry sooner.
            if "internal error during execution" in error_msg.lower():
                delay = retry_delays[attempt]
                print(f"This appears to be a timing issue. Waiting {delay} seconds before retry...")
            else:
                delay = retry_delays[attempt] // 2
                print(f"Waiting {delay} seconds before retry...")
            time.sleep(delay)

# --- Gradio App Logic ---
def get_initial_projects():
    """Populate the project dropdown when the app loads.

    Returns a ``gr.update`` for the dropdown and a status message for the
    setup tab. Choices are labelled ``"<display name> (<project id>)"`` and
    the first one is preselected.
    """
    available = list_projects()
    if not available:
        return gr.update(choices=[], value=None), "❌ No projects found. Please ensure you have access to at least one GCP project."
    labels = [f"{proj['name']} ({proj['id']})" for proj in available]
    status = f"✅ Found {len(available)} projects. Select a project and click Proceed."
    return gr.update(choices=labels, value=labels[0]), status

def proceed_with_project(project_selection, progress=gr.Progress()):
    """Validate and set up the project chosen in the dropdown.

    Checks billing, enables any missing ``REQUIRED_APIS``, then runs
    ``setup_project``. The module globals ``genai_client``, ``bq_client``,
    ``journal_impact_dict`` and ``PROJECT_ID`` are updated only after setup
    succeeds, so a failed attempt never pairs old clients with a new
    project ID.

    Args:
        project_selection: Dropdown label ``"<display name> (<project id>)"``.
        progress: Gradio progress tracker.

    Returns:
        ``(status, analyze_btn_update, tabs_update)``. ``status`` is one of:
        - the sentinel ``"billing_needed"`` when billing is not enabled;
        - a ✅ message on success;
        - a ❌ message on failure.
    """
    global genai_client, bq_client, journal_impact_dict, PROJECT_ID
    if not project_selection:
        return "❌ Please select a project first.", gr.update(interactive=False), gr.update()

    # The project ID sits inside the last pair of parentheses of the label.
    project_id = project_selection.split('(')[-1].rstrip(')')
    os.environ['GOOGLE_CLOUD_PROJECT'] = project_id

    try:
        progress(0.1, desc="Checking billing status...")
        if not check_billing_enabled(project_id):
            # Sentinel status: the Proceed handler switches to the billing form.
            return "billing_needed", gr.update(interactive=False), gr.update()

        progress(0.2, desc="Checking required APIs...")
        enabled_apis = list_enabled_apis(project_id)
        missing_apis = [api for api in REQUIRED_APIS if api not in enabled_apis]
        if missing_apis:
            enable_apis(project_id, missing_apis, progress)

        # Use the shared setup logic
        new_genai, new_bq, new_journals = setup_project(project_id, LOCATION, USER_DATASET, progress)
    except Exception as e:
        return f"❌ Error: {e}", gr.update(interactive=False), gr.update()

    genai_client, bq_client, journal_impact_dict = new_genai, new_bq, new_journals
    PROJECT_ID = project_id
    status = f"✅ Setup complete for {PROJECT_ID}! You can now analyze a case."
    return status, gr.update(interactive=True), gr.update(selected=2)

def _as_cell(value):
    """Flatten list/tuple/dict model output into a string that fits a single DataFrame cell."""
    if isinstance(value, (list, tuple)):
        return "; ".join(str(item) for item in value)
    if isinstance(value, dict):
        return json.dumps(value)
    return value


def _merge_article_analyses(articles_df, analyses):
    """Return ``articles_df`` with per-article model fields added as columns.

    ``analyses[i]`` is matched to row ``i`` by position. Extra analyses are
    ignored, and rows without an analysis get None. Columns that came from
    the search (e.g. PMID) are never overwritten by model output. Whole
    columns are assigned at once to avoid pandas dtype clashes from cell-by-
    cell writes.
    """
    merged = articles_df.reset_index(drop=True)
    per_row = [a if isinstance(a, dict) else {} for a in list(analyses)[:len(merged)]]
    per_row += [{}] * (len(merged) - len(per_row))
    new_keys = []
    for analysis in per_row:
        for key in analysis:
            if key not in merged.columns and key not in new_keys:
                new_keys.append(key)
    for key in new_keys:
        merged[key] = [_as_cell(analysis.get(key)) for analysis in per_row]
    return merged


def run_analysis(case_text, num_articles, progress=gr.Progress()):
    """Run the full pipeline for the Case tab.

    The steps are: extract disease/events → PubMed vector search → Gemini
    per-article analysis → score and rank.

    Args:
        case_text: Patient case notes.
        num_articles: Number of articles to retrieve (slider value).
        progress: Gradio progress tracker.

    Returns:
        ``(results_table, status, results_state, tabs_update)``. On success
        ``results_table`` holds the top 10 rows with ``score``, ``title``,
        ``journal_title`` and ``year`` columns, and the Results tab is
        selected. On failure the table is None and the status carries the error.
    """
    if not genai_client or not bq_client:
        return None, "❌ Please complete setup first.", {}, gr.update()
    try:
        progress(0.1, desc="Extracting medical info...")
        medical_info = extract_medical_info(case_text, genai_client)
        disease = medical_info.get('disease', '')
        events = [e.strip() for e in medical_info.get('events', '').split(',') if e.strip()]

        progress(0.3, desc="Searching PubMed...")
        embedding_model_path = f"{PROJECT_ID}.{USER_DATASET}.textembed"
        articles_df = search_pubmed_articles(disease, events, bq_client, embedding_model_path, PUBMED_TABLE, num_articles)
        if articles_df.empty:
            return None, f"⚠️ No matching PubMed articles found for '{disease}'.", {}, gr.update()

        progress(0.6, desc="Analyzing articles...")
        analyses = analyze_article_batch(articles_df, disease, events, genai_client, journal_impact_dict)
        articles_df = _merge_article_analyses(articles_df, analyses)

        scoring_config = SCORING_PRESETS["Clinical Focus"]
        articles_df['score'] = articles_df.apply(lambda row: calculate_article_score(row, scoring_config), axis=1)
        articles_df = articles_df.sort_values('score', ascending=False).reset_index(drop=True)

        progress(0.9, desc="Generating results...")
        # reindex guarantees the display columns exist even if the model omitted a field.
        results_table = articles_df.reindex(columns=['score', 'title', 'journal_title', 'year']).head(10)
        results = {'articles': articles_df.to_dict('records'), 'disease': disease, 'events': events, 'case_text': case_text}
        return results_table, f"✅ Analysis complete for '{disease}'.", results, gr.update(selected=3)
    except Exception as e:
        return None, f"❌ Analysis failed: {e}", {}, gr.update()

# Custom CSS passed to gr.Blocks below: use Google Sans across the app and
# remove the background, border, shadow and padding drawn around labels.
# NOTE(review): the trailing `.label-wrap` rule repeats `border: none`, which
# the combined selector above already sets.
css = """
.gradio-container { font-family: 'Google Sans', sans-serif; }
label, .label-wrap, .gradio-label { 
    background-color: transparent !important; 
    border: none !important; 
    box-shadow: none !important; 
    padding: 0 !important; 
}
.label-wrap {
    border: none !important;
}
"""
# Build the Gradio UI as four tabs: Get Started (id 0), Setup (1), Case (2),
# Results (3). Handlers switch tabs with gr.update(selected=<id>).
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=css) as demo:
    gr.Markdown("# 🏥 PubMed Literature Analysis")
    # Receives the full results dict (third output of run_analysis).
    app_state = gr.State({})

    with gr.Tabs() as tabs:
        with gr.TabItem("Get Started", id=0):
            gr.Markdown("## Welcome to the PubMed Literature Analysis Tool")
            gr.Markdown("This tool helps you analyze medical cases using PubMed literature with BigQuery vector search and Gemini. Get started by setting up your Google Cloud project.")
            gr.Markdown("""⚠️ **Important Notice: Demonstration Tool**\n\nThis PubMed literature analysis tool is a **DEMONSTRATION** showcasing AI-powered research capabilities.\n\n- For **research and educational purposes only**\n- **NOT** intended for treatment planning or clinical decisions\n- All AI-generated analyses should be verified against primary sources\n- Results may contain inaccuracies or limitations\n- Users are responsible for appropriate use within research contexts\n\nBy proceeding, you acknowledge these limitations and agree to use this tool responsibly for research purposes only.\n""")
            start_button = gr.Button("Get Started", variant="primary")

        with gr.TabItem("1. Setup", id=1):
            # Shared status line for project loading, setup and billing messages.
            status_output = gr.Markdown(value="Loading projects...")
            with gr.Row():
                project_dropdown = gr.Dropdown(label="Select Google Cloud Project", interactive=True)
                create_project_btn = gr.Button("Create New Project")

            # Hidden until "Create New Project" is clicked.
            with gr.Column(visible=False) as create_project_box:
                gr.Markdown("### Create New Google Cloud Project")
                new_project_id_input = gr.Textbox(label="New Project ID", placeholder="e.g., pubmed-analysis-123")
                billing_account_dropdown = gr.Dropdown(label="Select Billing Account")
                billing_link_message = gr.Markdown(visible=False)
                create_project_submit_btn = gr.Button("Create and Select Project", variant="primary")
                cancel_create_project_btn = gr.Button("Cancel")

            # Hidden until Proceed finds billing disabled on the selected project.
            with gr.Column(visible=False) as billing_setup_box:
                gr.Markdown("### 💳 Billing Setup Required")
                gr.Markdown("This project needs a billing account to use Google Cloud services.")
                billing_setup_dropdown = gr.Dropdown(label="Select Billing Account")
                billing_setup_message = gr.Markdown(visible=False)
                link_billing_btn = gr.Button("Link Billing Account", variant="primary")
                # NOTE(review): not used as an input or output by any handler in this cell.
                billing_status = gr.Markdown()

            with gr.Column() as setup_details_box:
                proceed_btn = gr.Button("Proceed", variant="primary")

        with gr.TabItem("2. Case", id=2):
            case_input = gr.Textbox(label="Patient Case Notes", value=SAMPLE_CASE, lines=10)
            # minimum 5, maximum 50, default 10 articles
            num_articles_slider = gr.Slider(5, 50, 10, step=1, label="Number of Articles to Analyze")
            # Disabled until a setup handler reports success.
            analyze_btn = gr.Button("Run Full Analysis", variant="primary", interactive=False)
            analysis_status = gr.Markdown()

        with gr.TabItem("3. Results", id=3):
            results_df = gr.DataFrame(label="Top 10 Ranked Articles")

    # --- Event Handlers for UI ---
    def show_create_project_form():
        """Open the create-project form with fresh billing choices and hide the Proceed section.

        Returns updates for (create_project_box, billing_account_dropdown, setup_details_box).
        """
        billing_choices = list_billing_accounts()
        default_choice = billing_choices[0] if billing_choices else None
        return (
            gr.update(visible=True),
            gr.update(choices=billing_choices, value=default_choice),
            gr.update(visible=False),
        )

    def hide_create_project_form():
        """Close the create-project form and show the Proceed section again."""
        return gr.update(visible=False), gr.update(visible=True)

    def handle_billing_selection(billing_account):
        """React to a billing-dropdown change.

        Returns updates for (the dropdown itself, its message Markdown):
        - ``None`` leaves both untouched;
        - the "create new account" pseudo-option clears the selection and
          shows creation instructions;
        - a real account hides the instructions.
        """
        if billing_account is None:
            return gr.update(), gr.update()
        if billing_account != CREATE_BILLING_ACCOUNT_OPTION:
            return gr.update(), gr.update(visible=False)
        status_msg = f"\n\n📋 **To create a billing account:**\n\n1. Open this link in your browser: {CREATE_BILLING_ACCOUNT_URL}\n2. Complete the billing account setup\n3. Restart the Gradio app and select your new billing account from the dropdown\n\n"
        return gr.update(value=None), gr.update(value=status_msg, visible=True)

    def handle_project_creation(project_id, billing_account, progress=gr.Progress()):
        """Create the project, then refresh the dropdown and jump to the Case tab.

        Returns updates for (create_project_box, setup_details_box,
        project_dropdown, status_output, tabs, analyze_btn). On failure only
        the status message changes.
        """
        status, new_project_selection = create_new_project(project_id, billing_account, progress)
        if not new_project_selection:
            return gr.update(), gr.update(), gr.update(), status, gr.update(), gr.update()
        choices = [f"{p['name']} ({p['id']})" for p in list_projects()]
        # A brand-new project may not appear in search results yet; keep it selectable anyway.
        if new_project_selection not in choices:
            choices.append(new_project_selection)
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(choices=choices, value=new_project_selection),
            status,
            gr.update(selected=2),
            # Setup succeeded, so the Case tab's analyze button must become usable.
            gr.update(interactive=True),
        )

    # Tab Switching
    start_button.click(lambda: gr.update(selected=1), None, tabs)

    # Setup Tab Interactions
    create_project_btn.click(show_create_project_form, outputs=[create_project_box, billing_account_dropdown, setup_details_box])
    cancel_create_project_btn.click(hide_create_project_form, outputs=[create_project_box, setup_details_box])
    billing_account_dropdown.change(handle_billing_selection, inputs=[billing_account_dropdown], outputs=[billing_account_dropdown, billing_link_message])
    create_project_submit_btn.click(
        handle_project_creation,
        inputs=[new_project_id_input, billing_account_dropdown],
        outputs=[create_project_box, setup_details_box, project_dropdown, status_output, tabs, analyze_btn]
    )
    def handle_billing_setup_selection(billing_account):
        """Handle billing account selection in the billing setup box.

        The rules are identical to ``handle_billing_selection`` (the
        create-project form's dropdown):
        - ignore ``None``;
        - turn the "create new account" pseudo-option into instructions;
        - otherwise hide the instructions.
        Delegating keeps the two dropdowns from drifting apart.

        Returns updates for (billing_setup_dropdown, billing_setup_message).
        """
        return handle_billing_selection(billing_account)

    def handle_link_billing(billing_account, project_dropdown, progress=gr.Progress()):
        """Link the chosen billing account to the selected project, then finish setup.

        ``project_dropdown`` is the project dropdown's label value.
        Returns updates for (status_output, billing_setup_box,
        setup_details_box). The chained ``.then`` handlers derive the analyze
        button and tab state from the status text.
        """
        if not billing_account or billing_account == CREATE_BILLING_ACCOUNT_OPTION:
            return "❌ Please select a valid billing account.", gr.update(visible=True), gr.update(visible=False)
        if not project_dropdown:
            return "❌ Please select a project first.", gr.update(visible=True), gr.update(visible=False)

        project_id = project_dropdown.split('(')[-1].rstrip(')')
        progress(0.1, desc="Linking billing account...")

        success, message = link_billing_to_project(project_id, billing_account)
        if not success:
            return message, gr.update(visible=True), gr.update(visible=False)

        progress(0.3, desc="Billing linked! Continuing setup...")
        # After a successful link, continue with the normal setup; only the
        # status is needed here (the .then chains handle button and tabs).
        status, _, _ = proceed_with_project(project_dropdown, progress)
        if status == "billing_needed":
            # check_billing_enabled still reported False (the new link may not be
            # visible yet, or the check itself errored); don't show the raw sentinel.
            status = "⚠️ Billing account linked, but billing is not reported as enabled yet. Please click Proceed again in a moment."
        return status, gr.update(visible=False), gr.update(visible=True)

    # Records whether the last Proceed click found billing disabled
    # (written by handle_proceed_click; no handler in this cell reads it).
    needs_billing_setup = gr.State(False)

    def handle_proceed_click(project_dropdown, progress=gr.Progress()):
        """Run project setup, switching to the billing form when billing is off.

        Returns updates for (status_output, analyze_btn, tabs,
        billing_setup_box, setup_details_box, billing_setup_dropdown,
        needs_billing_setup).
        """
        status, analyze_btn_update, tabs_update = proceed_with_project(project_dropdown, progress)

        if status != "billing_needed":
            # Normal flow: pass the setup result through and keep the billing form hidden.
            return (
                status,
                analyze_btn_update,
                tabs_update,
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(),
                False,
            )

        # Billing is off: show the billing form pre-filled with the available accounts.
        accounts = list_billing_accounts()
        first_account = accounts[0] if accounts else None
        return (
            "❌ Billing is not enabled for this project. Please set up billing to continue.",
            gr.update(interactive=False),
            gr.update(),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(choices=accounts, value=first_account),
            True,
        )

    proceed_btn.click(
        handle_proceed_click,
        inputs=[project_dropdown],
        outputs=[status_output, analyze_btn, tabs, billing_setup_box, setup_details_box, billing_setup_dropdown, needs_billing_setup],
    )

    # Billing setup handlers
    billing_setup_dropdown.change(
        handle_billing_setup_selection,
        inputs=[billing_setup_dropdown],
        outputs=[billing_setup_dropdown, billing_setup_message],
    )
    
    # Helper functions for the .then() chains
    def _status_text(status_markdown):
        """Return the plain text of a status Markdown value.

        Accepts a str, a ``{'value': ...}`` dict, or anything else (which is
        converted with ``str``).
        """
        if isinstance(status_markdown, dict) and 'value' in status_markdown:
            return str(status_markdown['value'])
        if isinstance(status_markdown, str):
            return status_markdown
        return str(status_markdown)

    def update_analyze_btn_based_on_status(status_markdown):
        """Enable the analyze button only when the status reports success (starts with ✅)."""
        return gr.update(interactive=_status_text(status_markdown).startswith("✅"))

    def update_tabs_based_on_status(status_markdown):
        """Switch to the Case tab (id 2) on success; otherwise leave the tabs unchanged."""
        if _status_text(status_markdown).startswith("✅"):
            return gr.update(selected=2)
        return gr.update()
    
    # Link billing first; once it finishes, derive the analyze-button and tab
    # state from whatever status message handle_link_billing wrote.
    link_billing_output = link_billing_btn.click(
        handle_link_billing,
        inputs=[billing_setup_dropdown, project_dropdown],
        outputs=[status_output, billing_setup_box, setup_details_box],
    )
    link_billing_output.then(update_analyze_btn_based_on_status, inputs=[status_output], outputs=[analyze_btn])
    link_billing_output.then(update_tabs_based_on_status, inputs=[status_output], outputs=[tabs])

    # Analysis Tab Interactions
    analyze_btn.click(
        run_analysis,
        inputs=[case_input, num_articles_slider],
        outputs=[results_df, analysis_status, app_state, tabs],
    )

    # Fill the project dropdown as soon as the page loads.
    demo.load(get_initial_projects, outputs=[project_dropdown, status_output])

# share=True also serves the app on a public link; debug=True keeps the cell
# running so server errors are printed in the cell output.
demo.launch(debug=True, share=True)