In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PubMed Medical Literature Analysis

<!-- [PLACEHOLDER: Update these links when notebook is finalized] -->
<table style="float: left; margin-right: 20px;">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FWandLZhang%2Fpubmed-rag%2Fmain%2FPubMed_RAG_Example.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/WandLZhang/pubmed-rag/main/PubMed_RAG_Example.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<!-- [PLACEHOLDER: Update share links when notebook is finalized] -->
<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>            

| Authors |
| --- |
| [Willis Zhang](https://github.com/WandLZhang) |
| [Stone Jiang](https://github.com/siduojiang) |

## Overview

**Blog Post: Medical Literature Analysis with PubMed, BigQuery, Gemini**

<a href="[blog-post-url-placeholder]" target="_blank">
  <img src="https://storage.googleapis.com/[placeholder-image-path]/medical-literature-blog-header.jpg" alt="Medical Literature Analysis with PubMed and Gemini" width="500">
</a>

This notebook demonstrates how to analyze medical cases using PubMed literature with BigQuery vector search and Gemini. It recreates the core user experience of the [Capricorn Medical Research Application](https://capricorn-medical-research.web.app/) as an interactive Colab notebook.


In this tutorial, you learn how to:

- Extract medical information (disease diagnosis and actionable events) from case notes
- Search PubMed literature using BigQuery vector search
- Score and rank articles using customizable criteria
- Develop evidence-based analysis with citations
- Create an interactive chat interface for medical discussions

![Medical Literature Analysis Architecture](https://github.com/WandLZhang/pubmed-rag/blob/main/visuals/1.png?raw=true)

This tutorial uses the following Google Cloud AI services and resources:

- **Vertex AI**: Gemini 2.5 Flash for text analysis and generation
- **BigQuery**: Vector search on PubMed article embeddings
- **Interactive Widgets**: Customizable scoring configuration

## Let's begin

1. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).
2. Install the following packages required to execute this notebook.

In [None]:
%pip install --upgrade --quiet google-genai google-cloud-bigquery google-cloud-bigquery-storage plotly pandas==2.2.2 db-dtypes

In [None]:
import sys

_IN_COLAB = "google.colab" in sys.modules

if _IN_COLAB:
    # Restart the Colab kernel so the packages just installed with %pip are
    # the ones that get imported by the following cells.
    import IPython

    IPython.Application.instance().kernel.do_shutdown(True)

3. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

4. After selecting your project, in the console click the main logo in the top-left to get to your project home.

![](https://github.com/WandLZhang/pubmed-rag/blob/main/visuals/2.png?raw=true)

This takes you to your project home. Copy the `Project ID` (highlighted by the orange box in the screenshot above) and paste it into the field below:

In [None]:
import os

PROJECT_ID = "wz-data-catalog-demo"  # @param {type: "string"}


def resolve_project_id(project_id, placeholder="[your-project-id]"):
    """Return the Google Cloud project ID to use for Vertex AI and BigQuery.

    When the form field is empty or still holds the placeholder, fall back to
    the ``GOOGLE_CLOUD_PROJECT`` environment variable.

    Raises:
        ValueError: if neither the form field nor the environment variable
            provides a project. (Previously ``str(None)`` silently produced
            the project ID "None", which only failed later inside API calls.)
    """
    if project_id and project_id != placeholder:
        return project_id
    env_project = os.environ.get("GOOGLE_CLOUD_PROJECT")
    if not env_project:
        raise ValueError(
            "No project ID: set PROJECT_ID above or the GOOGLE_CLOUD_PROJECT "
            "environment variable."
        )
    return env_project


PROJECT_ID = resolve_project_id(PROJECT_ID)

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

NOTE: You can change the `LOCATION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

5. Enable the [Vertex AI, BigQuery, and BigQuery Connection APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,discoveryengine.googleapis.com,bigquery.googleapis.com,bigqueryconnection.googleapis.com). The BigQuery Connection API is needed for the remote embedding model created below.
6. If you are running this notebook on Google Colab, you need to authenticate your environment. To do this, run the cell below. This step is not required if you are using Vertex AI Workbench.

In [None]:
import sys

# Colab sessions start without Google Cloud credentials, so sign the user in.
# On Vertex AI Workbench this step is skipped.
_running_on_colab = "google.colab" in sys.modules
if _running_on_colab:
    from google.colab import auth

    auth.authenticate_user()

## Medical Literature Analysis Pipeline

This section implements the complete medical literature analysis workflow, from case notes to treatment recommendations.

### 1. Initialize Vertex AI and BigQuery Configuration

In [None]:
# Model Configuration
MODEL_ID = "gemini-2.5-flash" # @param ["gemini-2.5-flash","gemini-2.5-pro"] {"allow-input":true, isTemplate: true}
THINKING_BUDGET = 0 # @param {type: "slider", min: 0, max: 24576, step: 1}

# Gemini 2.5 Pro cannot turn thinking off (its minimum thinking budget is 128
# tokens), so the slider's default of 0 would make every Pro request fail.
_PRO_MIN_THINKING_BUDGET = 128
if MODEL_ID.startswith("gemini-2.5-pro") and THINKING_BUDGET < _PRO_MIN_THINKING_BUDGET:
    print(
        f"⚠️ {MODEL_ID} does not support a thinking budget below "
        f"{_PRO_MIN_THINKING_BUDGET}; using {_PRO_MIN_THINKING_BUDGET}."
    )
    THINKING_BUDGET = _PRO_MIN_THINKING_BUDGET

# Create the Google Gen AI client, routed through Vertex AI in this project/region.
# The model itself is chosen per request via MODEL_ID.
from google import genai
from google.genai import types

client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)

# Initialize BigQuery client
from google.cloud import bigquery
bq_client = bigquery.Client(project=PROJECT_ID)

# Configure PubMed dataset (public dataset with embeddings)
# [PLACEHOLDER: Update to final public dataset location]
PUBMED_DATASET = "wz-data-catalog-demo.pubmed"
PUBMED_TABLE = f"{PUBMED_DATASET}.pmid_embed_nonzero_metadata"  # Combined embeddings and metadata table

# User's BigQuery dataset for embedding model
USER_DATASET = "pubmed"  # @param {type: "string"}
EMBEDDING_MODEL = f"{PROJECT_ID}.{USER_DATASET}.textembed"  # Text embedding model for vector search

# Create the user's dataset if it doesn't exist.
# NOTE(review): the dataset is created in LOCATION (default "us-central1"). A
# BigQuery query must read all its tables/models from one location, so confirm
# this matches the location of the PubMed dataset used in the vector search.
from google.cloud.exceptions import NotFound

dataset_ref = bigquery.DatasetReference(PROJECT_ID, USER_DATASET)
try:
    dataset = bq_client.get_dataset(dataset_ref)
    print(f"✅ Dataset '{USER_DATASET}' already exists")
except NotFound:
    # Only a genuinely missing dataset is created; auth/permission errors now
    # surface instead of being swallowed by a bare except.
    dataset = bigquery.Dataset(dataset_ref)
    dataset.location = LOCATION
    dataset = bq_client.create_dataset(dataset, exists_ok=True)
    print(f"✅ Created dataset '{USER_DATASET}'")

# Journal impact data will be loaded from CSV [PLACEHOLDER]
JOURNAL_IMPACT_CSV_URL = "https://raw.githubusercontent.com/WandLZhang/scimagojr_2024/main/scimagojr_2024.csv"

### 2. Create Text Embedding Model

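Before running vector searches, we create a BigQuery ML remote model that calls the Vertex AI `text-embedding-005` endpoint. BigQuery uses it to embed the search query inside the vector-search SQL.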

In [None]:
# Register a BigQuery ML remote model over the Vertex AI text-embedding-005
# endpoint using the project's default connection. IF NOT EXISTS makes
# re-running this cell a no-op once the model is in place.
create_model_query = f"""
CREATE MODEL IF NOT EXISTS `{EMBEDDING_MODEL}`
  REMOTE WITH CONNECTION DEFAULT
  OPTIONS(endpoint='text-embedding-005');
"""

try:
    query_job = bq_client.query(create_model_query)
    query_job.result()  # block until the DDL statement has finished
except Exception as e:
    print(f"❌ Failed to create embedding model: {str(e)}")
else:
    print(f"✅ Successfully created/verified embedding model: {EMBEDDING_MODEL}")

### 3. Setup Journal Impact Data in BigQuery

We'll create a BigQuery table of journal impact data (SCImago Journal Rank, SJR). It is used to score each article by the impact of the journal it was published in.

In [None]:
import pandas as pd

# Create Journal Impact Table in BigQuery
def _parse_sjr(value):
    """Convert a raw SJR cell from the SCImago CSV to a float, or None if blank.

    Commas are stripped before conversion (e.g. "145,004" -> 145004.0).

    NOTE(review): SCImago exports commonly use the comma as a *decimal*
    separator, in which case stripping it scales every SJR by 1000. The
    scoring code's normalize_journal_score comments assume SJR values up to
    ~100,000, i.e. this scale, so the behavior is kept -- confirm before
    changing either side.
    """
    if pd.isna(value) or str(value) == '':
        return None
    return float(str(value).replace(',', ''))


def setup_journal_impact_table():
    """Create and populate the journal impact table if it doesn't exist.

    Downloads the SCImago CSV, keeps rows with an SJR value, and loads them
    into ``{PROJECT_ID}.{USER_DATASET}.journal_impact``.

    Returns:
        True if the table already existed or was created, False if setup
        failed (the error is printed and the notebook continues without
        journal impact data).
    """
    table_ref = f"{PROJECT_ID}.{USER_DATASET}.journal_impact"

    try:
        from google.cloud.exceptions import NotFound

        # Only a missing table triggers creation; any other error (auth,
        # permissions, network) falls through to the outer handler.
        try:
            table = bq_client.get_table(table_ref)
            print(f"✅ Journal impact table already exists with {table.num_rows} rows")
            return True
        except NotFound:
            pass

        print(f"📊 Creating journal impact table: {table_ref}")

        # Download and parse the semicolon-separated CSV
        df = pd.read_csv(JOURNAL_IMPACT_CSV_URL, sep=';')
        df['SJR_float'] = df['SJR'].apply(_parse_sjr)

        # Select relevant columns and rename them to the table's column names
        columns_to_keep = {
            'Title': 'journal_title',
            'SJR_float': 'sjr',
            'Issn': 'issn',
            'SJR Best Quartile': 'sjr_best_quartile',
            'H index': 'h_index',
            'Publisher': 'publisher',
            'Categories': 'categories',
            'Country': 'country',
            'Type': 'type'
        }
        df_clean = df[list(columns_to_keep.keys())].rename(columns=columns_to_keep)

        # Remove rows with no SJR value
        df_clean = df_clean[df_clean['sjr'].notna()]
        print(f"📈 Cleaned data: {len(df_clean)} rows with valid SJR values")

        schema = [
            bigquery.SchemaField("journal_title", "STRING"),
            bigquery.SchemaField("sjr", "FLOAT64"),
            bigquery.SchemaField("issn", "STRING"),
            bigquery.SchemaField("sjr_best_quartile", "STRING"),
            bigquery.SchemaField("h_index", "INT64"),
            bigquery.SchemaField("publisher", "STRING"),
            bigquery.SchemaField("categories", "STRING"),
            bigquery.SchemaField("country", "STRING"),
            bigquery.SchemaField("type", "STRING"),
        ]
        job_config = bigquery.LoadJobConfig(
            schema=schema,
            write_disposition="WRITE_TRUNCATE",
        )

        print(f"⬆️ Uploading {len(df_clean)} rows to {table_ref}...")
        job = bq_client.load_table_from_dataframe(df_clean, table_ref, job_config=job_config)
        job.result()  # Wait for job to complete

        # Verify upload
        table = bq_client.get_table(table_ref)
        print(f"✅ Successfully created journal impact table with {table.num_rows} rows")
        return True

    except Exception as e:
        print(f"❌ Error setting up journal impact table: {e}")
        # Continue without journal impact data
        return False

# Setup the journal impact table
setup_journal_impact_table()

# Load journal data from BigQuery for local lookups
def load_journal_data_from_bigquery():
    """Fetch the journal_impact table as a ``{journal_title: sjr}`` dict.

    Rows are read in descending SJR order. Any failure (missing table, auth,
    network) is printed and an empty dict is returned, so later cells can
    still run without journal impact data.
    """
    try:
        query = f"""
        SELECT 
            journal_title,
            sjr
        FROM `{PROJECT_ID}.{USER_DATASET}.journal_impact`
        WHERE sjr IS NOT NULL
        ORDER BY sjr DESC
        """
        
        print("📥 Loading journal data from BigQuery...")
        journal_rows = bq_client.query(query).to_dataframe()

        # For a repeated title the later (lower-SJR) row wins, as with dict(zip(...)).
        journal_dict = {
            title: sjr
            for title, sjr in zip(journal_rows['journal_title'], journal_rows['sjr'])
        }
    except Exception as e:
        print(f"Error loading journal data from BigQuery: {e}")
        return {}

    print(f"✅ Loaded {len(journal_dict)} journals from BigQuery")
    return journal_dict

# In-memory {journal_title: sjr} lookup used when building the analysis prompts
journal_impact_dict = load_journal_data_from_bigquery()

### 3. Customizable Scoring System

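Configure how articles are scored based on various factors. This cell only defines the configuration class. The criteria and their weights are set later, in Step 1 of the analysis pipeline, where you can adjust them to match your research priorities.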

In [None]:
# Dynamic Scoring Configuration System

class DynamicScoringConfig:
    """Configuration class for dynamic scoring criteria.

    Each criterion is a dict with ``name`` and ``weight`` keys, plus optional
    ``description`` (default ``""``) and ``type`` (default ``"boolean"``, the
    same default calculate_article_score applies). Criteria of type
    ``"special"`` are scored by custom code rather than asked of the model.
    """
    
    def __init__(self, criteria_list):
        """
        Initialize with an iterable of criteria dictionaries.
        Each criterion should have: name, weight (and usually description, type).
        """
        # The criteria are iterated several times below; materialize non-list
        # iterables (e.g. generators) so later passes don't see them exhausted.
        # Lists are kept as-is so get_all_criteria() still returns the caller's list.
        if not isinstance(criteria_list, list):
            criteria_list = list(criteria_list)
        self.criteria_list = criteria_list
        self.config = {c['name']: c['weight'] for c in criteria_list}
        self.categories = {c['name']: {'description': c.get('description', '')} for c in criteria_list}
        self.criteria_by_name = {c['name']: c for c in criteria_list}
        
        # Criteria whose scores are computed by dedicated code paths
        self.special_criteria = ['journal_impact', 'year_penalty', 'event_match']
        
        # For compatibility with existing code
        self.default_categories = list(self.config.keys())
        
    def get_config(self):
        """Return the ``{name: weight}`` mapping."""
        return self.config
    
    def get_criteria(self):
        """Get criteria for analysis (excluding type ``"special"``)."""
        return [c for c in self.criteria_list if c.get('type', 'boolean') != 'special']
    
    def get_all_criteria(self):
        """Get all criteria including special ones."""
        return self.criteria_list
    
    def get_criterion(self, name):
        """Get a specific criterion by name ({} if unknown)."""
        return self.criteria_by_name.get(name, {})
    
    def get_category_types(self):
        """Return an empty dict; kept only for compatibility (no category types are used)."""
        return {}

# Note: The actual scoring_config will be initialized later in Step 1 with user-defined criteria
print("✅ Dynamic scoring system loaded. Criteria will be defined in Step 1: Run the Analysis Pipeline.")


### 4. Analysis Persona Configuration

Customize your research perspective to tailor how articles are analyzed.

In [None]:
# Analysis Persona Configuration
ANALYSIS_PERSONA = "You are a medical researcher analyzing literature for clinical relevance and treatment insights." # @param {type: "string"}

# Simple persona configuration class for compatibility
class SimplePersonaConfig:
    def __init__(self, persona):
        self.persona = persona
    
    def get_persona(self):
        return self.persona

# Initialize persona configuration  
persona_config = SimplePersonaConfig(ANALYSIS_PERSONA)
print("✅ Analysis persona configured")


### 5. Medical Information Extraction Functions

In [None]:
from google.genai.types import GenerateContentConfig

# Default extraction prompts. extract_medical_info appends
# "\n\nCase notes:\n<case text>" to whichever prompt it uses; callers can
# override either prompt via its disease_prompt / events_prompt arguments.

# Few-shot prompt asking for the primary diagnosis only, copied verbatim from the notes.
DISEASE_EXTRACTION_PROMPT = """You are an expert pediatric oncologist analyzing patient case notes to identify the primary disease.

Task: Extract the initial diagnosis exactly as written in the case notes.

Examples:
- Input: "A now almost 4-year-old female diagnosed with KMT2A-rearranged AML and CNS2 involvement..."
  Output: AML

- Input: "18 y/o boy, diagnosed in November 2021 with T-ALL with CNS1..."
  Output: T-ALL

- Input: "A 10-year-old patient with relapsed B-cell acute lymphoblastic leukemia (B-ALL)..."
  Output: B-cell acute lymphoblastic leukemia (B-ALL)

Output only the disease name. No additional text or formatting."""

# Few-shot prompt asking for exactly 5 general, searchable concepts in double
# quotes. The instructions say "one per line" but the example puts all five on
# one line; extract_medical_info extracts every quoted string, so either layout
# parses.
EVENT_EXTRACTION_PROMPT = """You are an expert pediatric oncologist analyzing patient case notes to identify key disease concepts and clinical features for literature search.

Task: Extract 5 general medical concepts that would help find relevant literature. Focus on:
- Disease types and subtypes (e.g., "AML", "T-ALL", "B-ALL")
- Genetic alterations (gene names only, e.g., "KMT2A rearrangement", "FLT3 mutation", "TP53 mutation")
- Treatment modalities (e.g., "HSCT", "chemotherapy", "CAR-T therapy", "stem cell transplant")
- General complications (e.g., "relapse", "refractory disease", "CNS involvement", "MRD positive")
- Anatomical sites or disease features (e.g., "bone marrow", "extramedullary disease")

Instructions:
- Extract GENERAL CONCEPTS that appear in medical literature
- DO NOT include patient-specific details like percentages, timeframes, or specific protocol names
- Focus on searchable medical terms
- Output exactly 5 concepts

Example:
Input: "A 4-year-old female with KMT2A-rearranged AML and CNS2 involvement exhibited refractory disease after NOPHO protocol. MRD remained at 35%. She relapsed 10 months after cord blood HSCT with 33% blasts. WES showed KMT2A::MLLT3 fusion and NRAS mutation."

Output: "AML" "KMT2A rearrangement" "CNS involvement" "refractory disease" "HSCT relapse"

Output only 5 general medical concepts, one per line in quotes. No additional text or formatting."""

def extract_medical_info(case_text, info_type="both", disease_prompt=None, events_prompt=None):
    """Extract the disease and/or actionable events from case notes with Gemini.

    Args:
        case_text: Free-text patient case notes.
        info_type: "disease", "events", or "both" (default): which extraction(s) to run.
        disease_prompt: Optional override for DISEASE_EXTRACTION_PROMPT.
        events_prompt: Optional override for EVENT_EXTRACTION_PROMPT.

    Returns:
        dict containing, depending on ``info_type``:
          - "disease": the model's answer, stripped;
          - "events": the raw events answer, stripped;
          - "events_list": events parsed from that answer (every double-quoted
            string if any quote is present, otherwise split on newlines/commas);
          - "events_with_ids": ``{"event_1": ..., "event_2": ...}``.
        A response with no text (``response.text is None``) is recorded as ""
        instead of raising AttributeError.
    """
    import re

    prompts = {
        "disease": disease_prompt or DISEASE_EXTRACTION_PROMPT,
        "events": events_prompt or EVENT_EXTRACTION_PROMPT,
    }

    results = {}
    for key, prompt in prompts.items():
        if info_type not in ("both", key):
            continue
        # Deterministic decoding; no max_output_tokens so answers are not cut short.
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[f"{prompt}\n\nCase notes:\n{case_text}"],
            config=GenerateContentConfig(
                temperature=0,
                thinking_config=types.ThinkingConfig(thinking_budget=THINKING_BUDGET),
            ),
        )
        results[key] = (response.text or "").strip()

    if 'events' in results:
        events_text = results['events']
        if '"' in events_text:
            # Events are quoted (the format the default prompt asks for)
            events_list = re.findall(r'"([^"]+)"', events_text)
        else:
            # Events are line-separated or comma-separated
            events_list = [e.strip() for e in events_text.replace('\n', ',').split(',') if e.strip()]

        results['events_list'] = events_list
        results['events_with_ids'] = {f"event_{i}": event for i, event in enumerate(events_list, 1)}

    return results

### 5. Article Scoring Functions

In [None]:
import math
from datetime import datetime

def normalize_journal_score(sjr, max_points):
    """Map an SJR value onto a 0..max_points scale with a logarithmic curve.

    Missing, zero or negative SJR values score 0. The result is
    log(sjr + 1) * max_points / 12, capped at max_points; since
    log(100000) is about 11.5, SJR values near 100,000 reach the cap.
    """
    if not sjr or sjr <= 0:
        return 0

    points_per_log_unit = max_points / 12
    return min(math.log(sjr + 1) * points_per_log_unit, max_points)

def _as_number(value):
    """Return ``value`` as a real number, or None if missing/unparseable.

    ints (including bools) pass through unchanged; anything else goes through
    ``float()``. None, NaN, infinities and text such as "N/A" yield None.
    """
    if isinstance(value, int):
        return value
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _is_truthy(value):
    """Interpret an extracted flag as a boolean.

    Strings are parsed, so "false"/"no"/"0"/"" count as False instead of being
    truthy merely for being non-empty. None, NaN, and values whose truthiness
    is undefined (``bool()`` raises) count as False.
    """
    if value is None:
        return False
    if isinstance(value, str):
        return value.strip().lower() not in {"", "false", "no", "0", "none", "null", "nan"}
    try:
        return bool(bool(value) and value == value)  # value != value only for NaN
    except (TypeError, ValueError):
        return False


def calculate_article_score(metadata, config, query_disease=None):
    """Score one article from its extracted metadata and the scoring criteria.

    Args:
        metadata: Dict-like (anything with ``.get``) of extracted fields:
            'journal_sjr', 'year', 'actionable_events', plus one field per
            custom criterion.
        config: An object exposing ``get_all_criteria()`` (e.g.
            DynamicScoringConfig) or, for backwards compatibility, a plain
            ``{name: weight}`` dict whose entries are scored as booleans.
        query_disease: Unused; kept for call compatibility.

    Returns:
        ``(score, breakdown)``: the total rounded to 2 decimals, and the
        points contributed per criterion (only contributing criteria appear).

    Special criteria:
        - journal_impact: log-normalized journal SJR (normalize_journal_score).
        - year_penalty: ``weight * (current_year - year)``, reported as 'year'.
        - event_match: ``weight`` per matched actionable event.
    Missing or unparseable values (None, NaN, "N/A", "false") add no points
    instead of raising.
    """
    score = 0
    breakdown = {}

    if hasattr(config, 'get_all_criteria'):
        all_criteria = config.get_all_criteria()
    else:
        all_criteria = [{'name': k, 'type': 'boolean', 'weight': v} for k, v in config.items()]

    for criterion in all_criteria:
        name = criterion['name']
        weight = criterion['weight']
        criterion_type = criterion.get('type', 'boolean')

        if name == 'journal_impact':
            sjr = _as_number(metadata.get('journal_sjr'))
            if sjr and sjr > 0:
                impact_points = normalize_journal_score(sjr, weight)
                score += impact_points
                breakdown['journal_impact'] = round(impact_points, 2)

        elif name == 'year_penalty':
            year = _as_number(metadata.get('year'))
            if year:
                year_points = weight * (datetime.now().year - int(year))
                score += year_points
                breakdown['year'] = year_points

        elif name == 'event_match':
            events = metadata.get('actionable_events', [])
            if isinstance(events, str):
                # Comma-separated list of the events the article mentions
                matched_events = len([e for e in events.split(',') if e.strip()])
            elif isinstance(events, list):
                # List of dicts flagged with 'matches_query'
                matched_events = sum(
                    1 for event in events
                    if isinstance(event, dict) and event.get('matches_query', False)
                )
            else:
                matched_events = 0  # None / NaN / unexpected type

            if matched_events > 0:
                event_points = matched_events * weight
                score += event_points
                breakdown['event_match'] = event_points

        elif criterion_type == 'boolean':
            if _is_truthy(metadata.get(name)):
                score += weight
                breakdown[name] = weight

        elif criterion_type in ('numeric', 'direct'):
            value = _as_number(metadata.get(name))
            if value:
                # numeric criteria are on a 0-100 scale; direct ones multiply the weight
                points = weight * (value / 100) if criterion_type == 'numeric' else weight * value
                score += points
                breakdown[name] = round(points, 2)

    return round(score, 2), breakdown

### 6. BigQuery Vector Search Functions

In [None]:
def generate_embedding(text):
    """Embed ``text`` with the Vertex AI ``text-embedding-005`` model.

    Uses the shared Gen AI ``client`` (already bound to PROJECT_ID/LOCATION)
    instead of ``vertexai.language_models``. That SDK is not installed by this
    notebook's %pip cell and is never initialised with ``vertexai.init()``
    here, so it would run against whatever default project it finds.

    Returns:
        list[float]: the embedding vector for ``text``.
    """
    response = client.models.embed_content(model="text-embedding-005", contents=[text])
    return response.embeddings[0].values

def search_pubmed_articles(disease, events_list, top_k=15, offset=0):
    """Return PubMed articles nearest to "<disease> <events...>" by vector similarity.

    The query text is embedded inside BigQuery with EMBEDDING_MODEL and matched
    against PUBMED_TABLE via VECTOR_SEARCH. Paging fetches ``top_k + offset``
    neighbours, orders them by distance and keeps ``top_k`` rows after
    skipping ``offset``.

    Args:
        disease: Disease name.
        events_list: Iterable of event strings (None is treated as empty).
        top_k: Page size.
        offset: Number of nearest results to skip.

    Returns:
        DataFrame with columns PMCID, PMID, content, distance.
    """
    events_text = ' '.join(events_list) if events_list is not None else ''
    query_text = f"{disease} {events_text}"
    top_k = int(top_k)
    offset = int(offset)

    # The text is model-derived, so it is bound as a query parameter rather
    # than spliced into a quoted SQL literal (a '"""' or backslash in it would
    # otherwise break or alter the query).
    sql = f"""
    WITH vector_results AS (
        SELECT base.name AS PMCID, base.PMID, base.content, distance 
        FROM VECTOR_SEARCH(
            TABLE `{PUBMED_TABLE}`, 
            'ml_generate_embedding_result', 
            (SELECT ml_generate_embedding_result 
             FROM ML.GENERATE_EMBEDDING(
                 MODEL `{EMBEDDING_MODEL}`, 
                 (SELECT @query_text AS content)
             )), 
            top_k => {top_k + offset}
        )
    )
    SELECT * FROM vector_results
    ORDER BY distance
    LIMIT {top_k}
    OFFSET {offset}
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("query_text", "STRING", query_text)]
    )

    return bq_client.query(sql, job_config=job_config).to_dataframe()

### 7. Article Analysis with Custom Categories

In [None]:
import json
import time

def analyze_article_batch(articles_df, disease, events_list, scoring_config, max_content_chars=3000):
    """Ask Gemini to extract scoring fields for every article in ``articles_df``.

    The prompt lists every journal in ``journal_impact_dict`` with its SJR so
    the model can fill ``journal_sjr``. It then asks for eight standard fields
    followed by one numbered field per criterion from
    ``scoring_config.get_criteria()``.

    Args:
        articles_df: DataFrame with a ``PMID`` column and a ``content`` (or
            ``abstract``) column.
        disease: Disease name used for ``disease_match``.
        events_list: Actionable events to look for.
        scoring_config: Object exposing ``get_criteria()`` (e.g. DynamicScoringConfig).
        max_content_chars: Per-article content limit, to bound prompt size.

    Returns:
        The parsed JSON response (expected: a list with one dict per article),
        or [] if the model returned no text or invalid JSON.
    """
    # Journal title -> SJR lookup table for the model
    journal_context = "".join(f"- {title}: {sjr}\n" for title, sjr in journal_impact_dict.items())

    field_specs = [
        f"disease_match: Does the article discuss {disease}? (true/false)",
        "title: Article title",
        "journal_title: Extract the journal name from the article and match it to the list above",
        "journal_sjr: Use the SJR score from the matched journal (0 if not found)",
        "year: Publication year",
        f"actionable_events: List which of these events are mentioned: {events_list}",
        "paper_type: Type of study (Clinical Trial, Review, Case Report, etc.)",
        "key_findings: Brief summary of main findings (1-2 sentences)",
    ]
    answer_format = {'boolean': 'true/false', 'numeric': '0-100 scale', 'direct': 'count'}
    for criterion in scoring_config.get_criteria():
        hint = answer_format.get(criterion.get('type', 'boolean'))
        if hint:  # unknown types are skipped without leaving a gap in the numbering
            field_specs.append(f"{criterion['name']}: {criterion['description']} ({hint})")

    criteria_text = "\n    ".join(f"{number}. {spec}" for number, spec in enumerate(field_specs, 1))

    prompt = f"""Analyze these medical research articles for relevance to:
    Disease: {disease}
    Actionable Events: {', '.join(events_list)}
    
    IMPORTANT: When extracting journal information, use the following journal impact data to find the matching journal title and its SJR score:
    
    Journal Impact Data (SJR scores):
{journal_context}
    
    For each article, extract:
    {criteria_text}
    
    Return as JSON array with one object per article.
    
    Articles:
    """

    article_chunks = []
    for _, article in articles_df.iterrows():
        content = article.get('content', article.get('abstract', ''))
        if not isinstance(content, str):
            content = ''  # NULL/NaN content would otherwise crash the slice below
        # Mark truncation only when content was actually cut
        ellipsis = "..." if len(content) > max_content_chars else ""
        article_chunks.append(
            f"\n---\nPMID: {article['PMID']}\nArticle Content: {content[:max_content_chars]}{ellipsis}\n"
        )
    articles_text = "".join(article_chunks)

    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[prompt + articles_text],
        config=GenerateContentConfig(
            temperature=0,
            response_mime_type="application/json",
        )
    )

    try:
        return json.loads(response.text)
    except (TypeError, ValueError):
        # TypeError: response.text is None; ValueError includes json.JSONDecodeError
        print("Failed to parse response:", response.text)
        return []


### 8. Two-Phase Analysis with AI.GENERATE_TABLE

Implement a two-phase approach for efficient article analysis:
- **Phase 1**: Quick event coverage check to identify relevant articles
- **Phase 2**: Full analysis of articles using AI.GENERATE_TABLE

In [None]:
def build_dynamic_schema(criteria):
    """Build the column list for the AI.GENERATE_TABLE output schema.

    Six standard columns come first, followed by one column per criterion:
    boolean -> BOOL, numeric/direct -> INT64. Criteria named 'journal_impact'
    or 'year', and criteria of any other type, add no column. Columns are
    joined by a comma, a newline and a 4-space indent, ready to splice into SQL.
    """
    sql_type_for = {'boolean': 'BOOL', 'numeric': 'INT64', 'direct': 'INT64'}
    reserved_names = ('journal_impact', 'year')

    columns = [
        "title STRING",
        "journal_title STRING",
        "journal_sjr FLOAT64",
        "year STRING",
        "paper_type STRING",
        "actionable_events STRING",
    ]
    columns.extend(
        f"{c['name']} {sql_type_for[c['type']]}"
        for c in criteria
        if c['name'] not in reserved_names and c['type'] in sql_type_for
    )
    return ",\n    ".join(columns)


def analyze_article_batch_with_criteria(df, disease, events, bq_client, journal_dict, persona, criteria):
    """Analyze articles using AI.GENERATE_TABLE directly on BigQuery table.

    Builds a single prompt and prepends it to each article's ``content``
    inside BigQuery. The prompt contains the persona, the disease and events,
    a journal -> SJR lookup table and the per-criterion extraction
    instructions. The ``gemini_generation`` remote model then returns one
    structured row per article, using the schema from ``build_dynamic_schema``.

    Args:
        df: Articles to analyze. ``PMCID`` values are matched against the
            table's ``name`` column. ``content`` and (optionally) ``PMID`` are
            copied back onto each result.
        disease: Disease name inserted into the prompt.
        events: Iterable of actionable event strings inserted into the prompt.
        bq_client: BigQuery client used to run the query.
        journal_dict: Mapping of journal title -> SJR score shown to the model.
        persona: Text placed at the top of the prompt.
        criteria: Criterion dicts with ``name``, ``type`` and ``description``.

    Returns:
        A list of dicts, one per returned row:

        - ``numeric``/``direct`` criterion fields are coerced to ``int``
          (0 when missing or unparsable);
        - ``year`` is normalized to a stripped string (``''`` when missing);
        - ``content``/``PMID`` are taken from ``df``.

        Returns ``[]`` when ``df`` is empty, when it has no usable PMCIDs, or
        when any step fails (the error is printed).
    """
    if df.empty:
        return []
    
    try:
        print(f"\n📊 Starting AI.GENERATE_TABLE analysis for {len(df)} article(s)...")
        ai_start_time = time.time()
        
        # Journal lookup table the model uses to fill journal_title / journal_sjr
        journal_context = "".join(f"- {title}: {sjr}\n" for title, sjr in journal_dict.items())
        
        # One instruction line per per-article criterion (special ones are handled elsewhere)
        criteria_instructions = []
        for criterion in criteria:
            if criterion['name'] in ['journal_impact', 'year']:
                continue
            if criterion['type'] == 'boolean':
                criteria_instructions.append(f"- {criterion['name']} (boolean): {criterion['description']}")
            elif criterion['type'] == 'numeric':
                criteria_instructions.append(f"- {criterion['name']} (number): {criterion['description']} (Return 0 if unknown)")
            elif criterion['type'] == 'direct':
                criteria_instructions.append(f"- {criterion['name']} (number 0-100): {criterion['description']} (Return 0 if no matches or unknown)")
        
        criteria_text = "\n".join(criteria_instructions)
        
        # Build dynamic schema
        schema = build_dynamic_schema(criteria)
        
        # Build the complete prompt in Python first
        full_prompt = f"""{persona}

Analyze this article for relevance to:
Disease: {disease}
Events: {', '.join(events)}

IMPORTANT: When extracting journal information, use the following journal impact data to find the matching journal title and its SJR score:

Journal Impact Data (SJR scores):
{journal_context}

For each article, extract the following information:
1. Standard fields (always extract these):
   - title: Article title (if unknown, return empty string)
   - journal_title: Name of the journal (if unknown, return empty string)
   - journal_sjr: Use the SJR score from the matched journal (0 if not found)
   - year: Publication year as a string (e.g., "2023"). If unknown or not found, return empty string, NOT null or NaN
   - paper_type: Type of paper (e.g., clinical trial, review, case report)
   - actionable_events: Comma-separated list of events found in the article

2. Evaluation criteria:
{criteria_text}

IMPORTANT: For all numeric fields, always return 0 instead of null, NaN, or leaving the field empty.

Article content:
"""
        
        # BigQuery string literals (triple-quoted ones included) treat backslash as an
        # escape character. Double every backslash first, then escape every double
        # quote, so persona/journal text can neither form escape sequences nor close
        # the literal early.
        full_prompt_escaped = full_prompt.replace('\\', '\\\\').replace('"', '\\"')
        
        # PMCIDs are matched against the table's `name` column. Drop missing ids
        # (None *and* NaN) so they are not queried as the string 'nan'.
        pmcids = [str(pmcid) for pmcid in df['PMCID'].tolist() if pd.notna(pmcid)]
        if not pmcids:
            print("Warning: No valid PMCIDs found in batch for full analysis")
            return []
        # Escape each id for a single-quoted BigQuery string literal
        pmcids_str = "', '".join(
            pmcid.replace('\\', '\\\\').replace("'", "\\'") for pmcid in pmcids
        )
        
        # Format schema for single line
        schema_single_line = schema.replace('\n', ' ').replace('    ', '')
        
        # Construct AI.GENERATE_TABLE query using PMCID as primary identifier
        query = f'''
        SELECT 
            PMCID,
            PMID,
            * EXCEPT (PMCID, PMID, prompt, full_response, status)
        FROM 
        AI.GENERATE_TABLE(
            MODEL `{PROJECT_ID}.{USER_DATASET}.gemini_generation`,
            (
                SELECT 
                    name AS PMCID,
                    PMID,
                    CONCAT(
                        """{full_prompt_escaped}""",
                        content
                    ) AS prompt
                FROM `{PUBMED_TABLE}`
                WHERE name IN ('{pmcids_str}')
            ),
            STRUCT(
                """{schema_single_line}""" AS output_schema,
                8192 AS max_output_tokens,
                0 AS temperature,
                0.95 AS top_p
            )
        )
        '''
        
        # Execute query
        query_execution_start = time.time()
        results_df = bq_client.query(query).to_dataframe()
        query_execution_time = time.time() - query_execution_start
        print(f"   ⚡ AI.GENERATE_TABLE query executed in {query_execution_time:.2f} seconds")
        
        # INT64 criterion columns that need NaN/null cleanup (hoisted out of the row loop)
        int_fields = [c['name'] for c in criteria if c['type'] in ['numeric', 'direct']]
        
        results = []
        for _, result_row in results_df.iterrows():
            result_dict = result_row.to_dict()
            
            # Normalize year to a plain string; missing values (None/NaN/'null') become ''
            if 'year' in result_dict:
                year_val = result_dict['year']
                if isinstance(year_val, str):
                    year_val = year_val.strip()
                    result_dict['year'] = '' if year_val in ('NaN', 'nan', 'null') else year_val
                elif pd.isna(year_val):
                    result_dict['year'] = ''
            
            # Clean up all INT64 fields (numeric and direct type criteria)
            for field in int_fields:
                if field not in result_dict:
                    continue
                field_value = result_dict[field]
                if pd.isna(field_value) or field_value in [None, 'NaN', 'nan', 'null', '']:
                    result_dict[field] = 0
                else:
                    try:
                        result_dict[field] = int(float(str(field_value)))
                    except (ValueError, TypeError):
                        print(f"Warning: Could not convert {field} value '{field_value}' to int, defaulting to 0")
                        result_dict[field] = 0
            
            # Copy content (and PMID for PubMed links) from the source row with the same PMCID
            matching_row = df[df['PMCID'] == result_dict.get('PMCID')]
            if not matching_row.empty:
                source_row = matching_row.iloc[0]
                result_dict['content'] = source_row.get('content')
                if 'PMID' in matching_row.columns:
                    result_dict['PMID'] = source_row.get('PMID')
            results.append(result_dict)
        
        total_ai_time = time.time() - ai_start_time
        print(f"   ✅ Total AI.GENERATE_TABLE analysis took {total_ai_time:.2f} seconds")
        
        return results
        
    except Exception as e:
        # Best effort: report and return no results rather than abort the notebook
        print(f"Error in AI.GENERATE_TABLE analysis: {str(e)}")
        return []


### 9. Complete Medical Analysis Pipeline

Now let's put it all together to analyze medical cases with the two-phase approach.

In [None]:
def _normalize_actionable_events(value):
    """Return an analysis 'actionable_events' value as a list of event strings.

    The value may be a list, or a comma-separated string (the AI.GENERATE_TABLE
    schema in this notebook declares the column as STRING). Iterating a string
    directly would treat every character as an event. Missing values give [].
    """
    if isinstance(value, str):
        return [part.strip() for part in value.split(',') if part.strip()]
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value if item is not None and str(item).strip()]
    return []


def _shorten_title(title, fallback, limit):
    """Return `title` cut to `limit` characters (ending in '...'), or `fallback` if it is not a string."""
    if not isinstance(title, str):
        title = fallback
    return title if len(title) <= limit else title[:limit - 3] + "..."


def process_medical_case(case_text, 
                        default_articles=5,     # Articles per batch
                        min_per_event=3,        # Minimum articles per event
                        max_articles=50):       # Maximum total to search
    """
    Complete pipeline to process medical case notes with two-phase analysis.

    Phase 1 extracts the disease and actionable events. It then searches PubMed
    in batches until every event is mentioned in at least ``min_per_event``
    articles (case-insensitive substring match) or ``max_articles`` articles
    have been fetched.

    Phase 2 analyzes each article individually with ``analyze_article_batch``
    (using the module-level ``scoring_config``) and merges the analysis into
    the article table. It then scores every article with
    ``calculate_article_score`` and sorts by score.
    
    Parameters:
    - case_text: The medical case description
    - default_articles: Number of articles to retrieve per batch (default: 5)
    - min_per_event: Minimum articles required per actionable event (default: 3)
    - max_articles: Maximum total articles to search (default: 50). The last
      batch is shrunk so this limit is never exceeded.

    Returns:
    A dict with these keys:
    - 'disease', 'events', 'case_text'
    - 'articles': DataFrame sorted by 'score' descending (empty if nothing was found)
    - 'event_coverage': event id -> list of covering PMCIDs
    - 'total_searched'
    """
    
    print("🔬 Extracting medical information...")
    medical_info = extract_medical_info(case_text)
    disease = medical_info.get('disease', '')
    events_with_ids = medical_info.get('events_with_ids', {})
    events_list = medical_info.get('events_list', [])
    
    print(f"\n📋 Disease: {disease}")
    print(f"🧬 Actionable Events: {', '.join(events_list)}")
    
    # Phase 1: Progressive search with event coverage tracking
    print(f"\n🔍 Phase 1: Searching for articles with event coverage...")
    print(f"   Target: {min_per_event} articles per event")
    
    event_coverage = {event_id: [] for event_id in events_with_ids.keys()}
    total_articles_searched = 0
    all_articles = []
    
    while total_articles_searched < max_articles:
        # Stop as soon as every event has minimum coverage
        if all(len(pmcids) >= min_per_event for pmcids in event_coverage.values()):
            print(f"✅ All events have minimum coverage!")
            break
        
        # Never fetch past max_articles, even when it is not a multiple of the batch size
        batch_size = min(default_articles, max_articles - total_articles_searched)
        print(f"\n   Searching articles {total_articles_searched + 1}-{total_articles_searched + batch_size}...")
        
        articles_df = search_pubmed_articles(
            disease, events_list, 
            top_k=batch_size, 
            offset=total_articles_searched
        )
        
        if articles_df.empty:
            print("   No more articles found.")
            break
        
        all_articles.append(articles_df)
        
        # Quick event coverage check (simplified version)
        # In the full app, this uses AI.GENERATE_TABLE for batch processing
        for _, row in articles_df.iterrows():
            content = row.get('content', '')
            if not isinstance(content, str):  # None/NaN content would crash .lower()
                content = ''
            content_lower = content.lower()
            pmcid = row.get('PMCID')
            
            for event_id, event_text in events_with_ids.items():
                if event_text.lower() in content_lower and pmcid not in event_coverage[event_id]:
                    event_coverage[event_id].append(pmcid)
        
        total_articles_searched += len(articles_df)
        
        # Report coverage
        print("\n   Event coverage status:")
        for event_id, event_text in events_with_ids.items():
            count = len(event_coverage[event_id])
            status = "✓" if count >= min_per_event else " "
            print(f"   {status} {event_text}: {count}/{min_per_event}")
    
    if not all_articles:
        print("❌ No articles found")
        return {
            'disease': disease,
            'events': events_list,
            'articles': pd.DataFrame(),
            'case_text': case_text,
            'event_coverage': event_coverage,
            'total_searched': total_articles_searched
        }
    
    articles_df = pd.concat(all_articles, ignore_index=True)
    n_articles = len(articles_df)
    print(f"\n📊 Phase 2: Analyzing {n_articles} articles...")
    
    # Analyze articles ONE BY ONE for better feedback
    all_analyses = []
    
    print("\n🔄 Starting detailed analysis...")
    for idx, (_, article_row) in enumerate(articles_df.iterrows()):
        pmid = article_row.get('PMID', 'N/A')
        print(f"\n📄 Analyzing article {idx + 1}/{n_articles}: PMID {pmid}")
        
        analysis = {}
        try:
            analysis_result = analyze_article_batch(
                pd.DataFrame([article_row]), disease, events_list, scoring_config
            )
        except Exception as e:
            print(f"   ❌ Error analyzing article: {str(e)}")
        else:
            if analysis_result:
                analysis = analysis_result[0] or {}
                
                # Show immediate feedback
                print(f"   ✅ Title: {_shorten_title(analysis.get('title'), 'Unknown title', 70)}")
                print(f"   📚 Journal: {analysis.get('journal_title', 'Unknown')}")
                print(f"   📅 Year: {analysis.get('year', 'N/A')}")
                
                events_found = _normalize_actionable_events(analysis.get('actionable_events'))
                if events_found:
                    print(f"   🎯 Events found: {', '.join(events_found)}")
            else:
                print(f"   ⚠️ No analysis results returned")
        # Exactly one entry per article keeps all_analyses aligned with the rows of
        # articles_df; the merge below pairs them by position.
        all_analyses.append(analysis)
    
    # Merge analysis with article data
    print("\n📝 Processing analysis results...")
    for idx, analysis in enumerate(all_analyses):
        if idx >= len(articles_df) or not analysis:
            continue
        row_label = articles_df.index[idx]
        for key, value in analysis.items():
            # Ensure the column exists before setting values
            if key not in articles_df.columns:
                articles_df[key] = None
            
            if key == 'actionable_events':
                # Mark which events match the query
                value = [
                    {
                        'event': event,
                        'matches_query': any(qe.lower() in event.lower() for qe in events_list)
                    }
                    for event in _normalize_actionable_events(value)
                ]
            articles_df.at[row_label, key] = value
    
    # Calculate scores
    print("\n🎯 Calculating scores...")
    config = scoring_config.get_config()
    scores = []
    breakdowns = []
    
    for _, article in articles_df.iterrows():
        score, breakdown = calculate_article_score(article.to_dict(), config, disease)
        scores.append(score)
        breakdowns.append(breakdown)
    
    articles_df['score'] = scores
    articles_df['score_breakdown'] = breakdowns
    
    articles_df = articles_df.sort_values('score', ascending=False)
    
    # Show final summary with top articles
    print(f"\n✅ Analysis complete! Found {len(articles_df)} articles.")
    print(f"\n🏆 Top 3 articles by score:")
    for rank, (_, article) in enumerate(articles_df.head(3).iterrows(), 1):
        # Rows whose analysis failed have no usable title (None/NaN)
        title = _shorten_title(article.get('title'), 'Unknown', 60)
        print(f"   {rank}. Score {article['score']:.1f}: {title}")
    
    return {
        'disease': disease,
        'events': events_list,
        'articles': articles_df,
        'case_text': case_text,
        'event_coverage': event_coverage,
        'total_searched': total_articles_searched
    }

### 9b. Results Visualization Functions

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML, Markdown

def visualize_results(results):
    """Create visualizations for the analysis results.

    Shows two Plotly figures:

    1. A bar chart of the first 10 rows of ``results['articles']`` (the
       pipeline sorts them by ``score``), one bar per PMID.
    2. A bar chart of the first row's ``score_breakdown`` mapping: green for
       positive contributions, red otherwise.

    Args:
        results: Dict with an ``'articles'`` DataFrame containing ``PMID``,
            ``score`` and ``score_breakdown`` columns.

    When there are no scored articles (e.g. the search found nothing, so
    ``articles`` is an empty DataFrame without columns), a message is printed
    and nothing is plotted.
    """
    articles_df = results['articles']
    if articles_df.empty or 'score' not in articles_df.columns:
        print("No scored articles to visualize.")
        return
    
    # Score distribution chart
    fig_scores = px.bar(
        articles_df.head(10),
        x='PMID',  
        y='score',
        title='Top 10 Articles by Score',
        labels={'PMID': 'PMID', 'score': 'Score'},
        color='score',
        color_continuous_scale='viridis'
    )
    # PMIDs are numeric. Force a categorical axis so each article gets its own,
    # evenly spaced bar instead of being placed on a numeric scale.
    fig_scores.update_layout(xaxis_tickangle=-45, xaxis_type='category')
    fig_scores.show()
    
    # Score breakdown for top article
    top_article = articles_df.iloc[0]
    breakdown = top_article.get('score_breakdown')
    if not breakdown or not hasattr(breakdown, 'items'):
        # Nothing to break down (missing/NaN breakdown)
        return
    
    fig_breakdown = go.Figure(data=[
        go.Bar(
            x=list(breakdown.keys()),
            y=list(breakdown.values()),
            marker_color=['green' if v > 0 else 'red' for v in breakdown.values()]
        )
    ])
    fig_breakdown.update_layout(
        title=f"Score Breakdown for Top Article (PMID: {top_article['PMID']})",
        xaxis_title="Scoring Factor",
        yaxis_title="Points"
    )
    fig_breakdown.show()


def display_top_articles(results, n=5):
    """Display detailed information about the top ``n`` articles as HTML cards.

    Args:
        results: Dict with an ``'articles'`` DataFrame (sorted by ``score``)
            containing ``PMID``, ``score`` and ``score_breakdown`` columns.
        n: Number of leading rows to render.

    Titles, journals and other fields come from PubMed content and model
    output. Every interpolated value is therefore HTML-escaped, so markup
    characters in that text cannot break or inject into the rendered card.
    """
    import html

    def _safe(value, default='N/A'):
        # Missing values (None/NaN) fall back to `default`; everything is escaped.
        if value is None or (isinstance(value, float) and value != value):
            value = default
        return html.escape(str(value))

    articles_df = results['articles']
    
    for idx, (_, article) in enumerate(articles_df.head(n).iterrows()):
        # Get journal name - handle both possible field names
        journal = article.get('journal_title', article.get('journal', 'Unknown'))
        pmid = _safe(article['PMID'])
        breakdown_items = ''.join(
            f"<li>{_safe(k)}: {_safe(v)}</li>" for k, v in article['score_breakdown'].items()
        )
        
        display(HTML(f"""
        <div style="border: 2px solid #ddd; padding: 15px; margin: 10px 0; border-radius: 5px;">
            <h3>#{idx + 1} - Score: {article['score']:.1f}</h3>
            <p><strong>Title:</strong> {_safe(article.get('title', 'N/A'))}</p>
            <p><strong>PMID:</strong> <a href="https://pubmed.ncbi.nlm.nih.gov/{pmid}" target="_blank">{pmid}</a></p>
            <p><strong>Journal:</strong> {_safe(journal, 'Unknown')} ({_safe(article.get('year', 'N/A'))})</p>
            <p><strong>Key Findings:</strong> {_safe(article.get('key_findings', 'N/A'))}</p>
            <details>
                <summary>Score Breakdown</summary>
                <ul>
                    {breakdown_items}
                </ul>
            </details>
        </div>
        """))

### 10. Final Analysis with Custom Prompt Templates

Generate comprehensive literature synthesis with customizable analysis prompts.

In [None]:
# Final Analysis Functions

def format_article_for_analysis(article, idx):
    """Format a single article as a plain-text section of the analysis prompt.

    Args:
        article: Mapping (dict or pandas Series) describing one article.
            Metadata fields are read from ``article['metadata']`` when present,
            otherwise from the article itself.
        idx: Position shown in the "Article {idx}:" header.

    ``actionable_events`` is accepted in any of these forms:

    - a comma-separated string;
    - a list of strings;
    - a list of ``{'event': ..., 'matches_query': ...}`` dicts (the form
      ``process_medical_case`` stores).

    The event names are joined with ", ". Missing or empty values are
    reported as "None identified".

    Returns:
        The formatted text block, including the article's full ``content``.
    """
    metadata = article.get('metadata', article)
    
    events_found = metadata.get('actionable_events')
    if isinstance(events_found, str):
        event_names = [events_found] if events_found else []
    elif isinstance(events_found, (list, tuple)):
        event_names = [
            event.get('event', '') if isinstance(event, dict) else str(event)
            for event in events_found
            if event is not None
        ]
        event_names = [name for name in event_names if name]
    else:
        event_names = []
    events_str = ', '.join(event_names) if event_names else "None identified"
    
    # Handle journal info - try different fields
    journal = metadata.get('journal_title', metadata.get('journal', 'Unknown'))
    
    # Check for both uppercase and lowercase field names
    pmid = article.get('PMID') or article.get('pmid') or metadata.get('PMID') or metadata.get('pmid') or 'N/A'
    pmcid = article.get('PMCID') or article.get('pmcid') or metadata.get('PMCID') or metadata.get('pmcid') or 'N/A'
    
    # A None score (e.g. not yet scored) would crash the :.1f format
    score = article.get('score', 0)
    if score is None:
        score = 0
    
    return f"""
Article {idx}:
Title: {metadata.get('title', 'Unknown')}
Journal: {journal} | Year: {metadata.get('year', 'N/A')}
Type: {metadata.get('paper_type', 'Unknown')}
Score: {score:.1f}
Key Concepts Found: {events_str}
PMID: {pmid} | PMCID: {pmcid}

Full Text:
{article.get('content', 'No content available')}...
"""

def create_final_analysis_prompt(case_text, disease, events, articles, custom_template):
    """Fill ``custom_template`` with the case details and every formatted article.

    Each article is rendered with ``format_article_for_analysis`` (numbered
    from 1). The sections are separated by a line of 80 ``=`` characters.

    Args:
        case_text: Original case description (``{case_description}``).
        disease: Primary focus (``{primary_focus}``).
        events: Key concepts, joined with ", " (``{key_concepts}``).
        articles: Articles to include (``{articles_content}``).
        custom_template: ``str.format`` template using the placeholders above.

    Returns:
        The filled prompt, or ``None`` when ``articles`` is empty.
    """
    if not articles:
        return None

    separator = "\n" + "=" * 80 + "\n"
    articles_content = separator.join([
        format_article_for_analysis(article, position)
        for position, article in enumerate(articles, start=1)
    ])

    return custom_template.format(
        case_description=case_text,
        primary_focus=disease,
        key_concepts=', '.join(events),
        articles_content=articles_content,
    )

In [None]:
def generate_final_analysis(results, articles_to_analyze, custom_template):
    """Stream a comprehensive literature synthesis for the selected articles.

    The prompt comes from ``create_final_analysis_prompt``, built from the case
    text, disease and events in ``results`` plus ``articles_to_analyze``. The
    reply is streamed from ``MODEL_ID``. While it streams, the last 50
    characters are echoed on one continually rewritten console line.

    Args:
        results: Dict with 'case_text', 'disease' and 'events' keys.
        articles_to_analyze: Sequence of article records to include.
        custom_template: Prompt template with the analysis placeholders.

    Returns:
        The generated text. If nothing could be generated, returns a message
        starting with "❌" or "Error generating response:" instead.
    """
    if not articles_to_analyze:
        return "❌ No articles available for analysis."

    print(f"🔄 Generating final analysis for {len(articles_to_analyze)} articles...")

    prompt = create_final_analysis_prompt(
        results['case_text'], results['disease'], results['events'],
        articles_to_analyze, custom_template,
    )
    if not prompt:
        return "❌ Could not create analysis prompt."

    full_response = ""
    try:
        print("Streaming tokens: ", end="", flush=True)

        stream = client.models.generate_content_stream(
            model=MODEL_ID,
            contents=[prompt],
            config=GenerateContentConfig(
                temperature=0.3,
                max_output_tokens=8192,
                thinking_config=types.ThinkingConfig(thinking_budget=THINKING_BUDGET),
            ),
        )
        for chunk in stream:
            piece = chunk.text
            if not piece:
                continue
            full_response += piece
            # Overwrite the same console line with the newest 50 characters
            preview = full_response[-50:].replace('\n', ' ')
            print(f"\rStreaming tokens: ...{preview}", end="", flush=True)

        print("\r✅ Analysis complete!                                                           ")
        return full_response

    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(f"\n❌ {error_msg}")
        return error_msg


### 11. Interactive Medical Consultation Chat

In [None]:
def medical_qa(results, question):
    """Ask a question about the medical case analysis with a streamed answer.

    Builds a context from three parts:

    - the disease;
    - the actionable events;
    - the first five rows of ``results['articles']`` (PMID, title and key
      findings).

    It then streams a Gemini answer that is asked to cite PMIDs.

    Args:
        results: Dict with 'disease', 'events' and an 'articles' DataFrame.
        question: Free-text question from the user.

    Returns:
        The full answer text, or an "Error generating response: ..." string
        if the request fails.
    """
    # Build context from top articles
    top_articles = results['articles'].head(5)
    
    context = f"""Medical Case Context:
Disease: {results['disease']}
Actionable Events: {', '.join(results['events'])}

Top Research Articles:
"""
    
    for _, article in top_articles.iterrows():
        # Rows whose per-article analysis failed may have no title column or a
        # None/NaN title; fall back instead of raising KeyError.
        title = article.get('title')
        if not isinstance(title, str):
            title = 'Unknown title'
        context += f"""
- PMID {article.get('PMID', 'N/A')}: {title}
  Key findings: {article.get('key_findings', 'N/A')}
"""
    
    prompt = f"""{context}

Question: {question}

Please provide a detailed, evidence-based response. Cite specific PMIDs when referencing research.
"""
    
    contents = [
        types.Content(
            role="user",
            parts=[types.Part(text=prompt)]
        )
    ]
    
    generate_content_config = types.GenerateContentConfig(
        temperature=0.3,
        top_p=0.95,
        max_output_tokens=8192,
        thinking_config=types.ThinkingConfig(
            thinking_budget=THINKING_BUDGET,
        ),
    )
    
    # Stream the response with visible tokens
    full_response = ""
    try:
        print("Streaming answer: ", end="", flush=True)
        
        for chunk in client.models.generate_content_stream(
            model=MODEL_ID,
            contents=contents,
            config=generate_content_config,
        ):
            if chunk.text:
                full_response += chunk.text
                # Show last 50 characters of the response
                display_text = full_response[-50:].replace('\n', ' ')
                print(f"\rStreaming answer: ...{display_text}", end="", flush=True)
        
        print("\r✅ Answer complete!                                                           ")
        return full_response
        
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(f"\n❌ {error_msg}")
        return error_msg


## Complete Example: Analyzing a Medical Case

Let's analyze a sample pediatric leukemia case through the complete pipeline.

In [None]:
# Example pediatric AML case notes used as pipeline input in the walkthrough below
SAMPLE_CASE = """
A 4-year-old male presents with a 3-week history of progressive fatigue, pallor, and easy bruising. 
Physical examination reveals hepatosplenomegaly and scattered petechiae. 

Laboratory findings:
- WBC: 45,000/μL with 80% blasts
- Hemoglobin: 7.2 g/dL
- Platelets: 32,000/μL

Flow cytometry: CD33+, CD13+, CD117+, CD34+, HLA-DR+, CD19-, CD3-

Cytogenetics: 46,XY,t(9;11)(p21.3;q23.3)
Molecular: KMT2A-MLLT3 fusion detected, FLT3-ITD positive, NRAS G12D mutation

Diagnosis: KMT2A-rearranged acute myeloid leukemia (AML)
"""

# Echo the heading and the notes, each on its own line
print("📋 Sample Case Notes:", SAMPLE_CASE, sep="\n")

### Step 1: Run the Analysis Pipeline

In [None]:
# Extract Medical Information from Case
# Runs extract_medical_info on SAMPLE_CASE and shows the disease and the
# actionable events it returned.
print("🔬 Extracting medical information from case notes...")
medical_info = extract_medical_info(SAMPLE_CASE)

# Display extracted information
print("\n📋 Extracted Information:")
print("=" * 60)
print(f"Disease: {medical_info.get('disease', '')}")
print(f"\nActionable Events:")
for i, event in enumerate(medical_info.get('events_list', []), 1):
    print(f"  {i}. {event}")
print("=" * 60)

# Allow users to modify extracted information if needed
# You can edit these values before running the analysis
# (The two lines below are Colab form fields; keep the `# @param` suffixes as-is.)
# NOTE(review): DISEASE and EVENTS are only displayed in this cell. The pipeline
# cell below calls process_medical_case(SAMPLE_CASE, ...), which runs
# extract_medical_info on the case text again, so edits here do not currently
# change the analysis.
DISEASE = medical_info.get('disease', '') # @param {type: "string"}
EVENTS = medical_info.get('events_list', []) # @param {type: "raw"}

# Convert events to list if user modified it as string
# (comma-separated; surrounding whitespace and empty entries are dropped)
if isinstance(EVENTS, str):
    EVENTS = [e.strip() for e in EVENTS.split(',') if e.strip()]

# Display final values that will be used
print("\n✅ Values to be used for analysis:")
print(f"Disease: {DISEASE}")
print(f"Events: {EVENTS}")

In [None]:
# Analysis Configuration Parameters
# Each default must lie inside its slider's min/max range, or the Colab form
# cannot represent it.
DEFAULT_ARTICLES = 2 # @param {type: "slider", min: 1, max: 20, step: 1}
MIN_ARTICLES_PER_EVENT = 1 # @param {type: "slider", min: 1, max: 10, step: 1}
MAX_ARTICLES_TO_SEARCH = 3 # @param {type: "slider", min: 1, max: 100, step: 1}

# Define your custom scoring criteria
# Each criterion should have:
# - name: unique identifier (will be used as field name)
# - description: what to look for in the article
# - type: 'boolean' (true/false), 'numeric' (0-100), 'direct' (direct count), or 'special' (computed)
# - weight: points to assign (can be negative for penalties)

CUSTOM_CRITERIA = [
    # Special criteria (these have custom computation logic)
    {"name": "journal_impact", "description": "High-impact journal (automatic SJR lookup)", "type": "special", "weight": 25},
    {"name": "year_penalty", "description": "Penalty per year old", "type": "special", "weight": -5},
    {"name": "event_match", "description": "Points per matching event", "type": "special", "weight": 15},
    
    # Quality Factors
    {"name": "novelty", "description": "Presents novel/innovative findings or approaches", "type": "boolean", "weight": 10},
    
    # Relevance Factors
    {"name": "disease_match", "description": "Discusses the specific disease from the case", "type": "boolean", "weight": 70},
    {"name": "pediatric_focus", "description": "Focuses on pediatric patients", "type": "boolean", "weight": 50},
    {"name": "treatment_shown", "description": "Shows treatment efficacy or outcomes", "type": "boolean", "weight": 80},
    {"name": "drugs_tested", "description": "Tests or discusses specific drugs/therapies", "type": "boolean", "weight": 5},
    
    # Study Types
    {"name": "clinical_trial", "description": "Is a clinical trial", "type": "boolean", "weight": 50},
    {"name": "review_article", "description": "Is a review article", "type": "boolean", "weight": -5},
    {"name": "case_report", "description": "Is a case report", "type": "boolean", "weight": 5},
    {"name": "case_series", "description": "Is a case series or series of case reports", "type": "boolean", "weight": 10},
    {"name": "cell_studies", "description": "Includes cell/in-vitro studies", "type": "boolean", "weight": 5},
    {"name": "animal_studies", "description": "Includes animal/mouse model studies", "type": "boolean", "weight": 10},
    {"name": "clinical_study", "description": "Is a clinical study (observational or interventional)", "type": "boolean", "weight": 15},
    {"name": "clinical_study_on_children", "description": "Is a clinical study specifically on children", "type": "boolean", "weight": 20},
    
    # Add your own criteria here! Examples:
    # {"name": "biomarker_analysis", "description": "Analyzes biomarkers or genetic markers", "type": "boolean", "weight": 15},
    # {"name": "survival_data", "description": "Includes survival or outcome data", "type": "boolean", "weight": 25},
    # {"name": "side_effects", "description": "Discusses treatment side effects or toxicity", "type": "boolean", "weight": 10},
    # {"name": "sample_size", "description": "Sample size (0-100 scale where 100 = very large study)", "type": "numeric", "weight": 0.3},
]

# Initialize dynamic scoring configuration
scoring_config = DynamicScoringConfig(CUSTOM_CRITERIA)

# Display current configuration
print("📊 Analysis Configuration")
print("=" * 60)
print(f"Articles per batch: {DEFAULT_ARTICLES}")
print(f"Minimum articles per event: {MIN_ARTICLES_PER_EVENT}")
print(f"Maximum articles to search: {MAX_ARTICLES_TO_SEARCH}")
print("\n📏 Scoring Criteria:")
print("-" * 60)
print(f"{'Criterion':<30} {'Type':<10} {'Weight':<10} {'Description'}")
print("-" * 60)
for criterion in CUSTOM_CRITERIA:
    # Truncate long descriptions to 40 characters for the table
    desc = criterion['description'][:40] + '...' if len(criterion['description']) > 40 else criterion['description']
    print(f"{criterion['name']:<30} {criterion['type']:<10} {criterion['weight']:<10} {desc}")
print("=" * 60)

# Process the case with configurable parameters
print("\n🚀 Starting analysis with custom criteria...")
results = process_medical_case(
    SAMPLE_CASE, 
    default_articles=DEFAULT_ARTICLES,
    min_per_event=MIN_ARTICLES_PER_EVENT,
    max_articles=MAX_ARTICLES_TO_SEARCH
)

### Step 2: Visualize Results

In [None]:
# Plot the top-10 score chart and the best article's score breakdown
visualize_results(results)

# Render detailed cards for the five highest-scoring articles
display_top_articles(results, 5)

### Step 3: Generate Final Literature Analysis

Select articles and customize the analysis prompt to generate a comprehensive literature synthesis.

In [None]:
# Step 3: Generate Final Literature Analysis

# Set your analysis prompt - customize this to get the type of analysis you want
ANALYSIS_PROMPT = """You are a research analyst synthesizing findings from a comprehensive literature review. Your goal is to provide insights that are valuable for research purposes.

RESEARCH CONTEXT:
Original Query/Case: {case_description}

Primary Focus: {primary_focus}
Key Concepts Searched: {key_concepts}

ANALYZED ARTICLES:
{articles_content}

Based on the research context and analyzed articles above, please provide a comprehensive synthesis in markdown format with the following sections:

## Literature Analysis: {primary_focus}

### 1. Executive Summary
Provide a concise overview of the key findings from the literature review, highlighting:
- Main themes identified across the literature
- Most significant insights relevant to the research query  
- Overall quality and quantity of available evidence
- Key takeaways for researchers in this field

### 2. Key Findings by Concept
| Concept | Articles Discussing | Key Findings | Evidence Quality |
|---------|-------------------|--------------|------------------|
[For each key concept searched, summarize what the literature reveals about it. In "Articles Discussing", list articles using their PMCID as clickable links, e.g., [PMC7654321](https://pmc.ncbi.nlm.nih.gov/articles/PMC7654321/)]

### 3. Methodological Landscape
| Research Method | Frequency | Notable Studies | Insights Generated |
|-----------------|-----------|-----------------|-------------------|
[Map the research methodologies used across the analyzed articles. Reference studies by PMCID]

### 4. Temporal Trends
| Time Period | Research Focus | Key Developments | Paradigm Shifts |
|-------------|----------------|------------------|-----------------|
[Analyze how research in this area has evolved over time. Cite articles using PMCID]

### 5. Cross-Study Patterns
| Pattern | Supporting Evidence | Implications | Confidence Level |
|---------|-------------------|--------------|------------------|
[Identify patterns that appear across multiple studies. List supporting evidence with PMCID references]

### 6. Controversies & Unresolved Questions
| Issue | Different Perspectives | Evidence For/Against | Current Consensus |
|-------|----------------------|---------------------|-------------------|
[Highlight areas of disagreement or ongoing debate in the literature. Cite specific articles by PMCID]

### 7. Knowledge Gaps & Future Research
| Gap Identified | Why It Matters | Potential Approaches | Expected Impact |
|----------------|----------------|---------------------|-----------------|
[Map areas where further research is needed based on the analyzed articles]

### 8. Practical Applications
Based on the synthesized literature, identify:
- How these findings can be applied in practice
- Recommendations for researchers entering this field
- Tools, methods, or frameworks that emerge from the literature
- Potential interdisciplinary connections

### 9. Quality & Reliability Assessment
Evaluate the overall body of literature:
- **Study Types**: Distribution of research designs (experimental, observational, reviews, etc.)
- **Sample Characteristics**: Common sample sizes, populations studied
- **Geographic Distribution**: Where research is being conducted
- **Publication Patterns**: Journal quality, publication years, citation patterns
- **Methodological Rigor**: Strengths and limitations observed

### 10. Synthesis & Conclusions
Provide an integrated narrative that:
- Connects findings across all analyzed articles
- Identifies the strongest evidence and most reliable findings
- Suggests how this research area is likely to develop
- Offers guidance for stakeholders interested in this topic

### 11. Bibliography
**Most Relevant Articles** (in order of relevance to the research query):
[For each article, format as follows:
- Title, Journal (Year). [PMCID: PMCxxxxxx](https://pmc.ncbi.nlm.nih.gov/articles/PMCxxxxxx/) | [PMID: xxxxxxxx](https://pubmed.ncbi.nlm.nih.gov/xxxxxxxx/)]

---

IMPORTANT NOTES:
- When referencing articles throughout the analysis, ALWAYS use their PMCID or PMID identifiers, not generic labels like "Article 1"
- Format all article references as clickable links: [PMCxxxxxx](https://pmc.ncbi.nlm.nih.gov/articles/PMCxxxxxx/)
- Maintain objectivity and clearly distinguish between strong evidence and preliminary findings
- Use accessible language while preserving scientific accuracy
- All claims must be traceable to specific articles in the analysis
- When evidence is conflicting, present all viewpoints fairly
- Focus on research insights and knowledge synthesis rather than prescriptive recommendations
- Highlight both the strengths and limitations of the current literature
""" # @param {type: "string"}

# Run the final literature analysis over every article retrieved earlier.
# Depends on `results`, `generate_final_analysis` (earlier cells) and on
# ANALYSIS_PROMPT (defined above in this cell).
num_articles = len(results['articles'])

if num_articles:
    from IPython.display import Markdown, display

    # presumably results['articles'] is a pandas DataFrame; 'records' yields one dict per row
    all_articles = results['articles'].to_dict('records')
    print(f"📊 Analyzing all {num_articles} retrieved articles...")

    final_analysis = generate_final_analysis(results, all_articles, ANALYSIS_PROMPT)

    # Render the heading and the model's Markdown report as rich output.
    display(Markdown("## 📊 Final Literature Analysis"))
    display(Markdown(final_analysis))
else:
    print("❌ No articles available for analysis.")


### Step 4: Interactive Medical Consultation

In [None]:
# Step 4: Interactive Medical Consultation
#
# Answers a free-text question about the case via `medical_qa` (defined in an
# earlier cell), using `results` from the earlier retrieval step.
from IPython.display import display, Markdown

# Ask a question about the medical case
MEDICAL_QUESTION = "What is the prognosis for this specific KMT2A rearrangement?" # @param {type: "string"}

# Example questions to try:
# - "What is the prognosis for this specific KMT2A rearrangement?"
# - "What are the key monitoring parameters during treatment?"
# - "How does the NRAS mutation affect treatment selection?"
# - "What combination therapies have shown promise for this disease profile?"

# A form field containing only whitespace is truthy but is not a real question,
# so normalize it before deciding whether to run the Q&A.
question_text = MEDICAL_QUESTION.strip()

if not question_text:
    print("💡 Enter a question above to get an evidence-based answer from the analyzed literature.")
else:
    qa_article_count = len(results['articles'])

    if qa_article_count == 0:
        # Mirror the Step 3 guard: with no retrieved articles there is no
        # literature to ground an answer in.
        print("❌ No articles available to answer the question.")
    else:
        # Display the question
        display(Markdown(f"### 💬 Q: {question_text}"))

        # Generate answer with streaming
        print(f"🤔 Analyzing question based on {qa_article_count} articles...")
        answer = medical_qa(results, question_text)

        # Display formatted answer
        display(Markdown("### 💡 Answer:"))
        display(Markdown(answer))

## Cleaning up

To avoid incurring charges to your Google Cloud account for the resources used in this notebook, follow these steps:

1. If you no longer need the project, delete it in the [Google Cloud console](https://console.cloud.google.com/). Learn more in the Google Cloud documentation for [managing and deleting your project](https://cloud.google.com/resource-manager/docs/creating-managing-projects).
2. If you no longer need it, disable the [Vertex AI API](https://console.cloud.google.com/apis/api/aiplatform.googleapis.com) in the Google Cloud console.