In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PubMed Medical Literature Analysis

<!-- [PLACEHOLDER: Update these links when notebook is finalized] -->
<table style="float: left; margin-right: 20px;">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FWandLZhang%2Fpubmed-rag%2Fmain%2FPubMed_RAG_Example.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/WandLZhang/pubmed-rag/main/PubMed_RAG_Example.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<!-- [PLACEHOLDER: Update share links when notebook is finalized] -->
<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/WandLZhang/pubmed-rag/blob/main/PubMed_RAG_Example.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>            

| Authors |
| --- |
| [Willis Zhang](https://github.com/WandLZhang) |
| [Stone Jiang](https://github.com/siduojiang) |

## Overview

**Blog Post: Medical Literature Analysis with PubMed, BigQuery, Gemini**

<a href="[blog-post-url-placeholder]" target="_blank">
  <img src="https://storage.googleapis.com/[placeholder-image-path]/medical-literature-blog-header.jpg" alt="Medical Literature Analysis with PubMed and Gemini" width="500">
</a>

This notebook demonstrates how to analyze medical cases using PubMed literature with BigQuery vector search and Gemini. It converts the basic user experience from the [Capricorn Medical Research Application](https://capricorn-medical-research.web.app/) into an interactive Colab notebook.


In this tutorial, you learn how to:

- Extract medical information (disease diagnosis and actionable events) from case notes
- Search PubMed literature using BigQuery vector search
- Score and rank articles using customizable criteria
- Generate evidence-based treatment recommendations with citations
- Create an interactive chat interface for medical discussions

![Medical Literature Analysis Architecture](https://github.com/WandLZhang/pubmed-rag/blob/main/visuals/1.png?raw=true)

This tutorial uses the following Google Cloud AI services and resources:

- **Vertex AI**: Gemini 2.5 Flash for text analysis and generation
- **BigQuery**: Vector search on PubMed article embeddings
- **Interactive Widgets**: Customizable scoring configuration

## Let's begin

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. After selecting your project, in the console click the main logo in the top-left to get to your project home.

![](https://github.com/WandLZhang/pubmed-rag/blob/main/visuals/2.png?raw=true)

This takes you to your project home page. Copy the `Project ID` (highlighted by the orange box in the screenshot above) and paste it into the field below:

In [None]:
import os

# Project used for Vertex AI and BigQuery. When the form field is left at its
# placeholder, fall back to the GOOGLE_CLOUD_PROJECT environment variable.
PROJECT_ID = "[your-project-id]"  # @param {type: "string"}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # Use "" (not str(None) == "None") when the variable is unset, so that a
    # missing project is detectable instead of silently becoming "None".
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "")

if not PROJECT_ID:
    print(
        "⚠️ PROJECT_ID is not set. Enter your project ID above or set the "
        "GOOGLE_CLOUD_PROJECT environment variable before continuing."
    )

# Region used by Vertex AI (and for the BigQuery dataset created later).
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

NOTE: You can change the `LOCATION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

2. Enable the [Vertex AI APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,discoveryengine.googleapis.com).
3. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).
4. Install the following packages required to execute this notebook.

In [None]:
%pip install --upgrade --quiet google-genai google-cloud-bigquery ipywidgets plotly pandas==2.2.2

5. If you are running this notebook on Google Colab, you need to authenticate your environment. To do this, run the cell below. This step is not required if you are using Vertex AI Workbench.

In [None]:
import sys

# Only Colab needs an explicit sign-in; the presence of the `google.colab`
# module in sys.modules is used as the "running on Colab" signal.
running_in_colab = "google.colab" in sys.modules

if running_in_colab:
    from google.colab import auth

    # Authenticate user to Google Cloud
    auth.authenticate_user()

## Medical Literature Analysis Pipeline

This section implements the complete medical literature analysis workflow, from case notes to treatment recommendations.

### 1. Initialize Vertex AI, BigQuery, and Dataset Configuration

In [None]:
# Initialize the Gemini model from Vertex AI:
from google import genai

client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
MODEL_ID = "gemini-2.5-flash" # @param ["gemini-2.5-flash-lite","gemini-2.5-flash","gemini-2.5-pro","gemini-2.0-flash"] {"allow-input":true, isTemplate: true}

# Initialize BigQuery client
from google.cloud import bigquery
from google.cloud.exceptions import NotFound

bq_client = bigquery.Client(project=PROJECT_ID)

# Configure PubMed dataset (public dataset with embeddings)
# [PLACEHOLDER: Update to final public dataset location]
PUBMED_DATASET = "wz-data-catalog-demo.pubmed"
PUBMED_TABLE = f"{PUBMED_DATASET}.pmid_embed_nonzero_metadata"  # Combined embeddings and metadata table

# User's BigQuery dataset for embedding model
USER_DATASET = "pubmed"  # @param {type: "string"}
EMBEDDING_MODEL = f"{PROJECT_ID}.{USER_DATASET}.textembed"  # Text embedding model for vector search

# Create the dataset if it doesn't exist.
# Only a NotFound error means "missing"; authentication, permission, or
# network failures are allowed to surface instead of being mistaken for a
# missing dataset.
try:
    dataset = bq_client.get_dataset(f"{PROJECT_ID}.{USER_DATASET}")
    print(f"✅ Dataset '{USER_DATASET}' already exists")
except NotFound:
    dataset = bigquery.Dataset(f"{PROJECT_ID}.{USER_DATASET}")
    dataset.location = LOCATION
    # exists_ok=True keeps this safe if the dataset appears between the check
    # and the create call.
    dataset = bq_client.create_dataset(dataset, exists_ok=True)
    print(f"✅ Created dataset '{USER_DATASET}'")

# Journal impact data will be loaded from CSV [PLACEHOLDER]
JOURNAL_IMPACT_CSV_URL = "https://raw.githubusercontent.com/WandLZhang/scimagojr_2024/main/scimagojr_2024.csv"

### 2. Create Text Embedding Model (if needed)

Before running vector searches, you need to ensure you have a text embedding model in BigQuery. Run the following query in the BigQuery console if you haven't created this model yet:

```sql
CREATE MODEL `your-project-id.your-dataset.textembed`
  REMOTE WITH CONNECTION DEFAULT
  OPTIONS(endpoint="text-embedding-005");
```

Replace `your-project-id` and `your-dataset` with your actual project ID and dataset name.

In [None]:
# Check if the embedding model exists
from google.cloud.exceptions import NotFound

try:
    # Ask the BigQuery API for the model resource directly; a NotFound error
    # is the only outcome that means the model has not been created yet.
    bq_client.get_model(EMBEDDING_MODEL)
    print(f"✅ Embedding model found: {EMBEDDING_MODEL}")
except NotFound:
    print(f"❌ Embedding model not found: {EMBEDDING_MODEL}")
    print("\nPlease create the model by running this query in BigQuery:")
    print("\n" + "="*60)
    print(f"CREATE MODEL `{EMBEDDING_MODEL}`")
    print("  REMOTE WITH CONNECTION DEFAULT")
    print("  OPTIONS(endpoint='text-embedding-005');")
    print("="*60)
    print("\nNote: You can run this query directly in the BigQuery console or in the cell below.")
except Exception as e:
    # Anything else (permissions, disabled API, network) is reported as-is so
    # it is not confused with a missing model.
    print(f"⚠️ Could not verify embedding model {EMBEDDING_MODEL}: {e}")

### 3. Load Journal Impact Data

In [None]:
# Load journal impact data from CSV
# [PLACEHOLDER: Update URL when repository is finalized]
import pandas as pd

# Read the journal table from the configured URL into a DataFrame.
journal_impact_df = pd.read_csv(JOURNAL_IMPACT_CSV_URL)

# Map each journal title to its SJR value for faster lookups.
journal_impact_dict = {
    title: sjr
    for title, sjr in zip(journal_impact_df['Title'], journal_impact_df['SJR'])
}
print(f"Loaded {len(journal_impact_dict)} journal impact records")

### 3. Customizable Scoring System

Configure how articles are scored based on various factors. Adjust the sliders to match your research priorities.

In [None]:
# Create dynamic scoring configuration interface
import ipywidgets as widgets

class DynamicScoringConfig:
    """Interactive ipywidgets editor for the article-scoring weights.

    Every scoring category is backed by an ``IntSlider``. Categories can be
    added and removed from the UI, and presets overwrite a subset of slider
    values. Use ``get_config()`` to read the current weights and
    ``get_category_types()`` to read each category's group.

    The interface is rendered inside ``self.output`` (an ``Output`` widget),
    so refreshes triggered from button callbacks replace the visible UI
    instead of depending on where the front end routes callback output.
    """

    # Slider values applied by each named preset, in application order.
    # 'Custom' is intentionally absent: selecting it leaves sliders untouched.
    PRESET_VALUES = {
        'Clinical Focus': {'clinical_trial': 50, 'treatment_shown': 80},
        'Research Focus': {'novelty': 30, 'clinical_trial': 30},
        'Pediatric Focus': {'pediatric_focus': 50, 'disease_match': 70, 'clinical_trial': 40},
        'Recent Evidence': {'year_penalty': -10, 'journal_impact': 35, 'novelty': 20},
    }

    def __init__(self):
        self.categories = {}
        self.widgets_list = []
        # Container that holds the rendered interface (see _refresh_display).
        self.output = widgets.Output()

        # Default categories
        self.default_categories = {
            'journal_impact': {'desc': 'Journal Impact Score', 'min': 0, 'max': 50, 'default': 25, 'type': 'quality'},
            'year_penalty': {'desc': 'Year Penalty (per year)', 'min': -20, 'max': 0, 'default': -5, 'type': 'quality'},
            'disease_match': {'desc': 'Disease Match', 'min': 0, 'max': 100, 'default': 50, 'type': 'relevance'},
            'pediatric_focus': {'desc': 'Pediatric Focus', 'min': 0, 'max': 50, 'default': 20, 'type': 'relevance'},
            'event_match': {'desc': 'Per Event Match', 'min': 0, 'max': 30, 'default': 15, 'type': 'relevance'},
            'treatment_shown': {'desc': 'Treatment Efficacy', 'min': 0, 'max': 100, 'default': 50, 'type': 'relevance'},
            'clinical_trial': {'desc': 'Clinical Trial', 'min': 0, 'max': 50, 'default': 40, 'type': 'study_type'},
            'review_article': {'desc': 'Review Article', 'min': -10, 'max': 10, 'default': -5, 'type': 'study_type'},
            'case_report': {'desc': 'Case Report', 'min': 0, 'max': 20, 'default': 5, 'type': 'study_type'},
            'case_series': {'desc': 'Case Series', 'min': 0, 'max': 20, 'default': 10, 'type': 'study_type'},
            'cell_studies': {'desc': 'Cell/In-vitro Studies', 'min': 0, 'max': 10, 'default': 5, 'type': 'study_type'},
            'animal_studies': {'desc': 'Animal/Mouse Studies', 'min': 0, 'max': 15, 'default': 10, 'type': 'study_type'},
            'clinical_study': {'desc': 'Clinical Study', 'min': 0, 'max': 30, 'default': 15, 'type': 'study_type'},
            'clinical_study_on_children': {'desc': 'Clinical Study on Children', 'min': 0, 'max': 40, 'default': 20, 'type': 'study_type'},
            'drugs_tested': {'desc': 'Drugs Tested', 'min': 0, 'max': 10, 'default': 5, 'type': 'relevance'},
            'novelty': {'desc': 'Novel Findings', 'min': 0, 'max': 30, 'default': 10, 'type': 'quality'}
        }

        # Initialize with default categories
        for key, config in self.default_categories.items():
            self.add_category(key, config['desc'], config['min'], config['max'], config['default'], config['type'])

        # Add category controls
        self.category_name = widgets.Text(
            description='Category ID:',
            placeholder='e.g., my_custom_score',
            style={'description_width': '100px'}
        )
        self.category_desc = widgets.Text(
            description='Description:',
            placeholder='e.g., My Custom Score',
            style={'description_width': '100px'}
        )
        self.category_min = widgets.IntText(
            value=-10, description='Min:',
            style={'description_width': '40px'},
            layout={'width': '150px'}
        )
        self.category_max = widgets.IntText(
            value=50, description='Max:',
            style={'description_width': '40px'},
            layout={'width': '150px'}
        )
        self.category_default = widgets.IntText(
            value=10, description='Default:',
            style={'description_width': '60px'},
            layout={'width': '150px'}
        )
        self.category_type = widgets.Dropdown(
            options=['quality', 'relevance', 'study_type', 'custom'],
            value='custom',
            description='Type:',
            style={'description_width': '60px'}
        )

        self.add_button = widgets.Button(
            description='Add Category',
            button_style='success',
            icon='plus'
        )
        self.add_button.on_click(self._add_category_click)

        # Feedback line shown under the "Add Category" button when the form
        # input is rejected.
        self.add_status = widgets.HTML()

        # Presets
        self.presets = widgets.Dropdown(
            options=['Custom', 'Clinical Focus', 'Research Focus', 'Pediatric Focus', 'Recent Evidence'],
            value='Custom',
            description='Presets:',
            style={'description_width': '100px'}
        )
        self.presets.observe(self.load_preset, names='value')

    def add_category(self, key, description, min_val, max_val, default_val, cat_type='custom'):
        """Add (or replace) a scoring category.

        Args:
            key: Identifier used as the key in ``get_config()``.
            description: Human-readable slider label.
            min_val: Lower slider bound.
            max_val: Upper slider bound.
            default_val: Initial slider value.
            cat_type: Group the category is shown under ('quality',
                'relevance', 'study_type' or 'custom').
        """
        slider = widgets.IntSlider(
            value=default_val, min=min_val, max=max_val, step=1,
            description=description + ':',
            style={'description_width': '200px'},
            layout={'width': '500px'}
        )

        remove_button = widgets.Button(
            description='',
            button_style='danger',
            icon='trash',
            layout={'width': '40px'}
        )
        remove_button.on_click(lambda b: self.remove_category(key))

        self.categories[key] = {
            'slider': slider,
            'type': cat_type,
            'remove_button': remove_button,
            'description': description
        }

    def remove_category(self, key):
        """Remove a scoring category and refresh the UI; unknown keys are ignored."""
        if key in self.categories:
            del self.categories[key]
            self._refresh_display()

    def _add_category_click(self, b):
        """Handle the "Add Category" button: validate the form, then add the slider."""
        key = self.category_name.value.strip().lower().replace(' ', '_')
        if not key:
            self.add_status.value = '<span style="color:#c00">Enter a category ID.</span>'
            return
        if key in self.categories:
            self.add_status.value = f'<span style="color:#c00">Category "{key}" already exists.</span>'
            return

        min_val = self.category_min.value
        max_val = self.category_max.value
        if min_val > max_val:
            # A slider cannot be built with min > max; reject instead of
            # letting the widget constructor raise inside the callback.
            self.add_status.value = '<span style="color:#c00">Min must not be greater than Max.</span>'
            return
        # Keep the initial value inside the slider range.
        default_val = max(min_val, min(self.category_default.value, max_val))

        self.add_category(
            key,
            self.category_desc.value or key,
            min_val,
            max_val,
            default_val,
            self.category_type.value
        )
        # Clear inputs
        self.category_name.value = ''
        self.category_desc.value = ''
        self.add_status.value = ''
        self._refresh_display()

    def load_preset(self, change):
        """Apply the preset named in ``change['new']`` to the existing sliders.

        Only categories that still exist are updated; sliders not mentioned by
        the preset keep their current value.
        """
        for key, val in self.PRESET_VALUES.get(change['new'], {}).items():
            if key in self.categories:
                self.categories[key]['slider'].value = val

    def get_config(self):
        """Return current configuration as a ``{category_key: slider_value}`` dictionary."""
        return {key: cat['slider'].value for key, cat in self.categories.items()}

    def get_category_types(self):
        """Return mapping of categories to their types."""
        return {key: cat['type'] for key, cat in self.categories.items()}

    def _build_ui(self):
        """Assemble the widget tree for the current set of categories."""
        # Group categories by type; categories with an unknown type are not shown.
        grouped = {'quality': [], 'relevance': [], 'study_type': [], 'custom': []}
        for cat in self.categories.values():
            rows = grouped.get(cat['type'])
            if rows is not None:
                rows.append(widgets.HBox([cat['slider'], cat['remove_button']]))

        type_names = {
            'quality': 'Quality Factors',
            'relevance': 'Relevance Factors',
            'study_type': 'Study Types',
            'custom': 'Custom Categories'
        }

        # One tab per non-empty group; the tab title is the first word of the name.
        tabs_content = []
        tab_names = []
        for cat_type, name in type_names.items():
            if grouped[cat_type]:
                tabs_content.append(widgets.VBox([
                    widgets.HTML(f'<h4>{name}</h4>'),
                    *grouped[cat_type]
                ]))
                tab_names.append(name.split()[0])

        tabs = widgets.Tab(children=tabs_content)
        for i, name in enumerate(tab_names):
            tabs.set_title(i, name)

        # Add category section
        add_section = widgets.VBox([
            widgets.HTML('<h4>Add Custom Category</h4>'),
            widgets.HBox([
                self.category_name,
                self.category_desc
            ]),
            widgets.HBox([
                self.category_min,
                self.category_max,
                self.category_default,
                self.category_type
            ]),
            self.add_button,
            self.add_status
        ], layout={'border': '1px solid #ddd', 'padding': '10px', 'margin': '10px 0'})

        return widgets.VBox([
            widgets.HTML('<h3>📊 Dynamic Scoring Configuration</h3>'),
            self.presets,
            tabs,
            add_section
        ])

    def _refresh_display(self):
        """Re-render the interface inside ``self.output``."""
        # Imported locally so the class does not depend on import order in the notebook.
        from IPython.display import clear_output, display as ipy_display

        ui = self._build_ui()
        with self.output:
            clear_output(wait=True)
            ipy_display(ui)

    def display(self):
        """Render the configuration interface and show it in the current cell."""
        from IPython.display import display as ipy_display

        self._refresh_display()
        ipy_display(self.output)

# Import for clear_output (IPython helper that clears a cell's output area)
from IPython.display import clear_output

# Initialize dynamic scoring configuration with the default categories and
# render its widget interface; `scoring_config` is passed to the analysis
# functions further down.
scoring_config = DynamicScoringConfig()
scoring_config.display()

### 4. Medical Information Extraction Functions

In [None]:
from google.genai.types import GenerateContentConfig

def extract_medical_info(case_text, info_type="both"):
    """Extract disease and actionable events from case notes.

    Args:
        case_text: Free-text case notes sent to the model.
        info_type: "disease", "events", or "both" (default) to select which
            extraction prompts are run.

    Returns:
        dict with a "disease" and/or "events" key holding the stripped model
        output. A key maps to "" when the model returned no text. Any other
        ``info_type`` yields an empty dict.
    """
    
    prompts = {
        "disease": """Extract the primary disease diagnosis from these case notes. 
        Return ONLY the disease name, nothing else. For example: 'KMT2A-rearranged AML'""",
        
        "events": """Extract all actionable medical events from these case notes including:
        - Genetic mutations (e.g., NRAS G12D, FLT3-ITD)
        - Chromosomal abnormalities (e.g., t(9;11))
        - Biomarkers (e.g., CD33+, CD19-)
        - Other molecular findings
        
        Return ONLY a comma-separated list of events, nothing else.
        Example: "NRAS G12D, FLT3-ITD, CD33+, t(9;11)"""
    }
    
    results = {}
    
    for key, prompt in prompts.items():
        if info_type not in ("both", key):
            continue

        full_prompt = f"{prompt}\n\nCase notes:\n{case_text}"

        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[full_prompt],
            config=GenerateContentConfig(
                temperature=0,
                max_output_tokens=200,
            )
        )

        # response.text can be None when the response carries no text part
        # (for example a blocked or empty candidate); treat that as "".
        results[key] = (response.text or "").strip()
    
    return results

### 5. Article Scoring Functions

In [None]:
import math
from datetime import datetime

# Divisor that maps log(sjr + 1) onto [0, max_points]; chosen by the original
# author from log(100000) ≈ 11.5.
_SJR_LOG_DIVISOR = 12


def normalize_journal_score(sjr, max_points):
    """Normalize journal SJR score to specified max points.

    Args:
        sjr: Journal SJR value. Missing, non-numeric, NaN, zero, and negative
            values all score 0.
        max_points: Upper bound of the returned score.

    Returns:
        ``log(sjr + 1) * max_points / 12`` capped at ``max_points``, or 0 when
        ``sjr`` is not a positive number.
    """
    try:
        sjr = float(sjr)
    except (TypeError, ValueError):
        return 0

    # NaN fails every comparison, so it must be rejected explicitly or it
    # would propagate into the article score.
    if math.isnan(sjr) or sjr <= 0:
        return 0

    # Use log scale to handle large range of SJR values
    normalized = math.log(sjr + 1) * (max_points / _SJR_LOG_DIVISOR)
    return min(normalized, max_points)

# Weights used when a category is absent from `config` (for example after it
# was removed in the scoring UI). Values mirror the default slider settings.
_DEFAULT_SCORE_WEIGHTS = {
    'journal_impact': 25,
    'year_penalty': -5,
    'disease_match': 50,
    'pediatric_focus': 20,
    'event_match': 15,
    'treatment_shown': 50,
    'clinical_trial': 40,
    'review_article': -5,
    'case_report': 5,
    'case_series': 10,
    'cell_studies': 5,
    'animal_studies': 10,
    'clinical_study': 15,
    'clinical_study_on_children': 20,
    'drugs_tested': 5,
    'novelty': 10,
}

# (metadata flag, config/breakdown key) pairs scored after the paper-type
# rule, in the order they are added to the breakdown.
_STUDY_FLAG_WEIGHTS = (
    ('case_report', 'case_report'),
    ('series_of_case_reports', 'case_series'),
    ('cell_studies', 'cell_studies'),
    ('mice_studies', 'animal_studies'),
    ('clinical_study', 'clinical_study'),
    ('clinical_study_on_children', 'clinical_study_on_children'),
    ('drugs_tested', 'drugs_tested'),
)


def _flag_is_set(value):
    """Truthiness check that treats NaN (a missing DataFrame cell) as unset."""
    if isinstance(value, float) and math.isnan(value):
        return False
    return bool(value)


def calculate_article_score(metadata, config, query_disease=None):
    """Calculate article score based on metadata and user configuration.

    Args:
        metadata: Mapping of per-article analysis fields (flags such as
            ``disease_match``, plus ``journal_sjr``, ``year``, ``paper_type``
            and ``actionable_events``). Missing, ``None`` and NaN values are
            treated as "not set".
        config: Mapping of category key to weight. Categories missing from it
            fall back to ``_DEFAULT_SCORE_WEIGHTS``.
        query_disease: Unused; kept for interface compatibility.

    Returns:
        ``(score, breakdown)`` where ``score`` is the total rounded to two
        decimals and ``breakdown`` maps each contributing factor to its points.
    """
    def weight(key):
        return config.get(key, _DEFAULT_SCORE_WEIGHTS[key])

    score = 0
    breakdown = {}

    # Journal impact - the configured weight is the maximum number of points
    raw_sjr = metadata.get('journal_sjr')
    if _flag_is_set(raw_sjr):
        try:
            sjr = float(raw_sjr)
        except (TypeError, ValueError):
            sjr = 0.0  # non-numeric SJR (e.g. "N/A") contributes nothing
        if sjr > 0:
            impact_points = normalize_journal_score(sjr, weight('journal_impact'))
            score += impact_points
            breakdown['journal_impact'] = round(impact_points, 2)

    # Year penalty: weight is applied once per year of article age
    raw_year = metadata.get('year')
    if _flag_is_set(raw_year):
        try:
            article_year = int(raw_year)
        except (TypeError, ValueError, OverflowError):
            article_year = None  # unparseable year: skip the penalty
        if article_year is not None:
            year_points = weight('year_penalty') * (datetime.now().year - article_year)
            score += year_points
            breakdown['year'] = year_points

    # Disease match
    if _flag_is_set(metadata.get('disease_match')):
        score += weight('disease_match')
        breakdown['disease_match'] = weight('disease_match')

    # Pediatric focus
    if _flag_is_set(metadata.get('pediatric_focus')):
        score += weight('pediatric_focus')
        breakdown['pediatric_focus'] = weight('pediatric_focus')

    # Actionable events: points per event flagged as matching the query
    events = metadata.get('actionable_events')
    if not isinstance(events, (list, tuple)):
        events = []
    matched_events = sum(
        1 for event in events
        if isinstance(event, dict) and event.get('matches_query', False)
    )
    if matched_events > 0:
        event_points = matched_events * weight('event_match')
        score += event_points
        breakdown['actionable_events'] = event_points

    # Treatment shown
    if _flag_is_set(metadata.get('treatment_shown')):
        score += weight('treatment_shown')
        breakdown['treatment_shown'] = weight('treatment_shown')

    # Paper type: clinical trial takes precedence over review
    paper_type = metadata.get('paper_type')
    paper_type = paper_type.lower() if isinstance(paper_type, str) else ''
    if 'clinical trial' in paper_type:
        score += weight('clinical_trial')
        breakdown['paper_type'] = weight('clinical_trial')
    elif 'review' in paper_type:
        score += weight('review_article')
        breakdown['paper_type'] = weight('review_article')

    # Other study types
    for flag, key in _STUDY_FLAG_WEIGHTS:
        if _flag_is_set(metadata.get(flag)):
            score += weight(key)
            breakdown[key] = weight(key)

    # Novelty - check both 'novelty' and 'novel_findings' fields
    if _flag_is_set(metadata.get('novelty')) or _flag_is_set(metadata.get('novel_findings')):
        score += weight('novelty')
        breakdown['novelty'] = weight('novelty')

    return round(score, 2), breakdown

### 6. BigQuery Vector Search Functions

In [None]:
def generate_embedding(text):
    """Generate a text embedding with the `text-embedding-005` model.

    Uses the notebook's existing Google Gen AI `client` so the call shares the
    project and location configured for it, instead of requiring a separate
    SDK to be installed and initialised.

    Args:
        text: The string to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    response = client.models.embed_content(
        model="text-embedding-005",
        contents=[text],
    )
    # One input string was sent, so the first embedding is the result.
    return list(response.embeddings[0].values)

def search_pubmed_articles(disease, events_list, top_k=15):
    """Search PubMed articles using BigQuery vector similarity.

    The search text is the disease followed by the space-joined events. It is
    embedded in BigQuery with ``ML.GENERATE_EMBEDDING`` and matched against
    ``PUBMED_TABLE`` with ``VECTOR_SEARCH``.

    Args:
        disease: Disease name extracted from the case notes.
        events_list: List of actionable-event strings.
        top_k: Number of nearest neighbours to return.

    Returns:
        DataFrame with columns PMCID, PMID, content, abstract and distance,
        ordered by ascending distance.
    """
    # Combine disease and events for search query
    query_text = f"{disease} {' '.join(events_list)}".strip()

    # The search text is model-generated from user-supplied notes, so it is
    # passed as a query parameter rather than spliced into the SQL, where
    # quotes or backslashes could break or alter the statement.
    sql = f"""
    WITH query_embedding AS (
      SELECT ml_generate_embedding_result AS embedding_col
      FROM ML.GENERATE_EMBEDDING(
        MODEL `{EMBEDDING_MODEL}`,
        (SELECT @query_text AS content),
        STRUCT(TRUE AS flatten_json_output)
      )
    )
    SELECT
      base.AccessionID as PMCID,
      base.PMID,
      base.content,  -- Full article text
      base.text as abstract,
      distance
    FROM VECTOR_SEARCH(
      TABLE `{PUBMED_TABLE}`,
      'ml_generate_embedding_result',
      (SELECT embedding_col FROM query_embedding),
      top_k => {int(top_k)}
    )
    ORDER BY distance ASC;
    """

    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("query_text", "STRING", query_text),
        ]
    )

    # Execute query
    results = bq_client.query(sql, job_config=job_config).to_dataframe()

    # Journal info is extracted from the content later during analysis;
    # for now, return the raw results
    return results

### 7. Article Analysis with Custom Categories

In [None]:
import json

def analyze_article_batch(articles_df, disease, events_list, scoring_config, max_content_chars=3000):
    """Analyze a batch of articles using Gemini.

    Args:
        articles_df: DataFrame of articles. The PMID is read from a ``PMID``
            or ``pmid`` column and the text from ``content``, falling back to
            ``abstract``.
        disease: Disease name the articles are assessed against.
        events_list: Actionable events the articles are assessed against.
        scoring_config: Object exposing ``categories`` and
            ``default_categories`` mappings; categories not present in the
            defaults are added to the prompt as custom true/false questions.
        max_content_chars: Maximum number of characters of each article's
            text included in the prompt.

    Returns:
        List of dicts, one per analysed article, parsed from the model's JSON
        output. Returns ``[]`` for an empty batch or an unparseable response.
    """
    if len(articles_df) == 0:
        return []

    # Build journal context for Gemini to look up journal titles
    journal_context = "".join(
        f"- {title}: {sjr}\n" for title, sjr in journal_impact_dict.items()
    )

    # Get custom categories from scoring config
    custom_categories = [key for key in scoring_config.categories
                         if key not in scoring_config.default_categories]

    # Build analysis prompt with custom categories
    custom_cat_prompts = ""
    if custom_categories:
        custom_cat_prompts = "\nCUSTOM SCORING CATEGORIES:\n"
        for cat_key in custom_categories:
            cat_info = scoring_config.categories[cat_key]
            custom_cat_prompts += f"- {cat_key}: {cat_info['description']} (score true/false)\n"

    prompt = f"""Analyze these medical research articles for relevance to:
    Disease: {disease}
    Actionable Events: {', '.join(events_list)}
    
    IMPORTANT: When extracting journal information, use the following journal impact data to find the matching journal title and its SJR score:
    
    Journal Impact Data (SJR scores):
{journal_context}
    
    For each article, extract:
    1. disease_match: Does the article discuss {disease}? (true/false)
    2. title: Article title
    3. journal_title: Extract the journal name from the article and match it to the list above
    4. journal_sjr: Use the SJR score from the matched journal (0 if not found)
    5. year: Publication year
    6. pediatric_focus: Is this about pediatric patients? (true/false)
    7. actionable_events: List which of these events are mentioned: {events_list}
    8. treatment_shown: Does it show treatment efficacy? (true/false)
    9. paper_type: Type of study (Clinical Trial, Review, Case Report, etc.)
    10. case_report: Is this a case report? (true/false)
    11. series_of_case_reports: Is this a case series? (true/false)
    12. cell_studies: Does it include cell/in-vitro studies? (true/false)
    13. mice_studies: Does it include animal/mouse studies? (true/false)
    14. novel_findings: Does it present novel findings? (true/false)
    15. key_findings: Brief summary of main findings (1-2 sentences)
    16. drugs_tested: Were any drugs tested? (true/false)
    17. clinical_study: Is this a clinical study? (true/false)
    18. clinical_study_on_children: Is this a clinical study on children? (true/false)
    19. novelty: Does it present novel/innovative approaches? (true/false)
    {custom_cat_prompts}
    
    Return as JSON array with one object per article.
    
    Articles:
    """

    # Format articles for analysis - use full content when available
    article_chunks = []
    for _, article in articles_df.iterrows():
        # The pipeline renames PMID -> pmid before calling this function, so
        # accept either spelling.
        pmid = article['PMID'] if 'PMID' in article else article.get('pmid', '')

        # Prefer full content, otherwise the abstract; None/NaN cells are
        # skipped so slicing below always operates on a string.
        content = ''
        for column in ('content', 'abstract'):
            candidate = article.get(column)
            if isinstance(candidate, str) and candidate:
                content = candidate
                break

        # Limit content for token management
        article_chunks.append(
            f"\n---\nPMID: {pmid}\nArticle Content: {content[:max_content_chars]}...\n"
        )
    articles_text = "".join(article_chunks)

    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[prompt + articles_text],
        config=GenerateContentConfig(
            temperature=0,
            response_mime_type="application/json",
        )
    )

    try:
        # TypeError covers response.text being None; ValueError covers
        # malformed JSON (json.JSONDecodeError is a ValueError subclass).
        parsed = json.loads(response.text)
    except (TypeError, ValueError):
        print("Failed to parse response:", response.text)
        return []

    # Callers iterate over a list of per-article dicts; normalise other shapes.
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list):
        return []
    return [item for item in parsed if isinstance(item, dict)]

### 8. Complete Medical Analysis Pipeline

Now let's put it all together to analyze medical cases.

In [None]:
def process_medical_case(case_text, num_articles=10):
    """Run the full case-analysis pipeline on free-text medical case notes.

    Steps, in order:
      1. ``extract_medical_info`` pulls the disease and actionable events
         out of ``case_text``.
      2. ``search_pubmed_articles`` retrieves candidate articles.
      3. ``analyze_article_batch`` is called on batches of 5 articles, using
         the notebook-level ``scoring_config``.
      4. Each analysis object is merged into the row of the article it
         describes; ``actionable_events`` entries are annotated with whether
         they match the events extracted from the case.
      5. ``calculate_article_score`` scores every article and the frame is
         sorted by score, best first.

    Args:
        case_text: Raw case notes.
        num_articles: Passed to ``search_pubmed_articles`` as ``top_k``.

    Returns:
        dict with keys ``'disease'`` (str), ``'events'`` (list of str),
        ``'articles'`` (DataFrame with ``pmid``/``pmcid`` columns, the merged
        analysis columns, ``score`` and ``score_breakdown``) and
        ``'case_text'``. Analysis columns hold ``None`` for articles the
        model returned no analysis for.
    """
    # Local import: only needed to build object-dtype columns during the merge.
    import pandas as pd

    print("🔬 Extracting medical information...")
    medical_info = extract_medical_info(case_text) or {}
    disease = medical_info.get('disease', '') or ''
    raw_events = medical_info.get('events', '') or ''
    # Events normally arrive as one comma-separated string; tolerate a list too.
    if isinstance(raw_events, str):
        raw_events = raw_events.split(',')
    events_list = [str(e).strip() for e in raw_events if str(e).strip()]

    print(f"\n📋 Disease: {disease}")
    print(f"🧬 Actionable Events: {', '.join(events_list)}")

    print("\n🔍 Searching PubMed for relevant articles...")
    found_df = search_pubmed_articles(disease, events_list, top_k=num_articles)
    print(f"Found {len(found_df)} articles")

    print("\n📊 Analyzing articles...")
    batch_size = 5
    n_rows = len(found_df)
    # One dict of analysis fields per article, addressed by row position.
    row_updates = [{} for _ in range(n_rows)]

    for start in range(0, n_rows, batch_size):
        # Batches keep the search result's own column names on purpose:
        # analyze_article_batch reads article['PMID'], so the rename to
        # lower-case must not happen before this call.
        batch_df = found_df.iloc[start:start + batch_size]
        batch_analysis = analyze_article_batch(batch_df, disease, events_list, scoring_config)
        # Align per batch so a failed batch cannot shift later analyses onto
        # the wrong articles.
        _collect_batch_updates(batch_df, batch_analysis, start, events_list, row_updates)

    # Rename columns to match the format expected downstream (returns a copy).
    articles_df = found_df.rename(columns={'PMID': 'pmid', 'PMCID': 'pmcid'})

    # Write each analysis field as a whole object-dtype column. Assigning a
    # list into a single cell via .loc is unreliable in pandas.
    merged_keys = list(dict.fromkeys(key for updates in row_updates for key in updates))
    for key in merged_keys:
        if key in articles_df.columns:
            values = articles_df[key].astype(object).tolist()
        else:
            values = [None] * n_rows
        for position, updates in enumerate(row_updates):
            if key in updates:
                values[position] = updates[key]
        articles_df[key] = pd.Series(values, index=articles_df.index, dtype=object)

    # Calculate scores
    config = scoring_config.get_config()
    scores = []
    breakdowns = []
    for _, article in articles_df.iterrows():
        score, breakdown = calculate_article_score(article.to_dict(), config, disease)
        scores.append(score)
        breakdowns.append(breakdown)

    articles_df['score'] = scores
    articles_df['score_breakdown'] = breakdowns

    # Best-scoring articles first
    articles_df = articles_df.sort_values('score', ascending=False)

    return {
        'disease': disease,
        'events': events_list,
        'articles': articles_df,
        'case_text': case_text
    }


def _collect_batch_updates(batch_df, batch_analysis, start, events_list, row_updates):
    """Record one batch's analysis objects in ``row_updates`` at the matching row positions."""
    if isinstance(batch_analysis, dict):
        # A lone object is only unambiguous when the batch held a single article.
        batch_analysis = [batch_analysis] if len(batch_df) == 1 else []
    elif not isinstance(batch_analysis, (list, tuple)):
        batch_analysis = []

    # Map PMID -> absolute row position so analyses that echo their PMID are
    # matched by identity instead of by order.
    id_column = next((c for c in ('PMID', 'pmid') if c in batch_df.columns), None)
    position_by_pmid = {}
    if id_column is not None:
        for offset, pmid in enumerate(batch_df[id_column]):
            position_by_pmid.setdefault(str(pmid).strip(), start + offset)

    for offset, analysis in enumerate(batch_analysis):
        if not isinstance(analysis, dict):
            continue
        reported_pmid = analysis.get('pmid', analysis.get('PMID'))
        position = None
        if reported_pmid is not None:
            position = position_by_pmid.get(str(reported_pmid).strip())
        if position is None:
            # Fall back to the order within this batch; drop surplus objects.
            if offset >= len(batch_df):
                continue
            position = start + offset

        for key, value in analysis.items():
            if key in ('pmid', 'PMID'):
                # Never let model output overwrite the article identifier.
                continue
            if key == 'actionable_events':
                value = _flag_matching_events(value, events_list)
            row_updates[position][key] = value


def _flag_matching_events(reported_events, query_events):
    """Annotate each reported event with whether it contains one of the query events (case-insensitive)."""
    if reported_events is None:
        reported = []
    elif isinstance(reported_events, (list, tuple)):
        reported = list(reported_events)
    else:
        # A bare string/scalar is treated as one event, not iterated per character.
        reported = [reported_events]

    lowered_queries = [str(q).lower() for q in query_events]
    return [
        {
            'event': event,
            'matches_query': any(q in str(event).lower() for q in lowered_queries)
        }
        for event in reported
    ]

### 9. Results Visualization Functions

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML, Markdown

def visualize_results(results, top_n=10):
    """Plot the analysis results: a score bar chart and a breakdown for the best article.

    Args:
        results: Dict produced by ``process_medical_case``; only
            ``results['articles']`` is read. It must provide ``pmid``,
            ``score`` and ``score_breakdown`` columns and is expected to be
            sorted best-first (the first row is treated as the top article).
        top_n: Maximum number of articles shown in the score chart.

    Returns:
        None. Figures are rendered with ``fig.show()``.
    """
    articles_df = results['articles']

    if len(articles_df) == 0:
        # Nothing to plot.
        print("No articles to visualize.")
        return

    # Work on a copy so the caller's frame is not modified, and render PMIDs
    # as strings: they are identifiers, not quantities.
    top_df = articles_df.head(top_n).copy()
    top_df['pmid'] = top_df['pmid'].astype(str)

    # Score distribution chart
    fig_scores = px.bar(
        top_df,
        x='pmid',
        y='score',
        title=f'Top {len(top_df)} Articles by Score',
        labels={'pmid': 'PMID', 'score': 'Score'},
        color='score',
        color_continuous_scale='viridis'
    )
    # Force a categorical axis; numeric-looking strings would otherwise be
    # auto-detected as numbers and spread along a linear axis.
    fig_scores.update_layout(xaxis_tickangle=-45, xaxis_type='category')
    fig_scores.show()

    # Score breakdown for top article
    top_article = articles_df.iloc[0]
    breakdown = top_article['score_breakdown']
    if not isinstance(breakdown, dict) or not breakdown:
        return

    factors = list(breakdown.keys())
    points = list(breakdown.values())
    fig_breakdown = go.Figure(data=[
        go.Bar(
            x=factors,
            y=points,
            # Positive contributions in green, zero/negative in red.
            marker_color=['green' if v > 0 else 'red' for v in points]
        )
    ])
    fig_breakdown.update_layout(
        title=f"Score Breakdown for Top Article (PMID: {top_article['pmid']})",
        xaxis_title="Scoring Factor",
        yaxis_title="Points"
    )
    fig_breakdown.show()

def display_top_articles(results, n=5):
    """Render an HTML card for each of the first ``n`` articles.

    Each card shows rank, score, title, a PubMed link, journal/year, key
    findings and a collapsible score breakdown. Every interpolated value is
    HTML-escaped because titles and model-generated findings are external
    text; missing values are shown as ``N/A``.

    Args:
        results: Dict produced by ``process_medical_case``; only
            ``results['articles']`` is read.
        n: Number of leading rows to display.

    Returns:
        None. Cards are rendered with ``display(HTML(...))``.
    """
    import html

    def as_text(value, default='N/A'):
        # None and NaN (NaN != NaN) mean "no value" for a DataFrame cell.
        if value is None or (isinstance(value, float) and value != value):
            return default
        return html.escape(str(value))

    articles_df = results['articles']

    for rank, (_, article) in enumerate(articles_df.head(n).iterrows(), start=1):
        breakdown = article.get('score_breakdown')
        if not isinstance(breakdown, dict):
            breakdown = {}
        breakdown_items = ''.join(
            f"<li>{as_text(factor)}: {as_text(points)}</li>"
            for factor, points in breakdown.items()
        )

        pmid = as_text(article['pmid'])
        title = as_text(article.get('title'))
        journal = as_text(article.get('journal'))
        year = as_text(article.get('year'))
        findings = as_text(article.get('key_findings'))

        display(HTML(f"""
        <div style="border: 2px solid #ddd; padding: 15px; margin: 10px 0; border-radius: 5px;">
            <h3>#{rank} - Score: {article['score']:.1f}</h3>
            <p><strong>Title:</strong> {title}</p>
            <p><strong>PMID:</strong> <a href="https://pubmed.ncbi.nlm.nih.gov/{pmid}" target="_blank">{pmid}</a></p>
            <p><strong>Journal:</strong> {journal} ({year})</p>
            <p><strong>Key Findings:</strong> {findings}</p>
            <details>
                <summary>Score Breakdown</summary>
                <ul>
                    {breakdown_items}
                </ul>
            </details>
        </div>
        """))

### 10. Treatment Recommendation Functions

In [None]:
def generate_treatment_recommendations(results, n_articles=5):
    """Ask the model for treatment recommendations grounded in the top articles.

    Args:
        results: Dict produced by ``process_medical_case``; reads
            ``'articles'``, ``'disease'``, ``'events'`` and ``'case_text'``.
        n_articles: Number of leading (best-scoring) articles to include as
            context.

    Returns:
        str: The model's response text. If there are no articles, or the
        model returns no text, an explanatory message is returned instead so
        callers can always render the result.
    """
    top_articles = results['articles'].head(n_articles)

    if len(top_articles) == 0:
        # The prompt requires PMID citations; with no articles to cite, the
        # model could only invent them, so do not call it at all.
        return ("No supporting articles are available, so evidence-based "
                "treatment recommendations could not be generated.")

    # Build context from top articles
    context_parts = []
    for _, article in top_articles.iterrows():
        findings = article.get('key_findings', 'N/A')
        # None/NaN (NaN != NaN) means the article has no findings recorded.
        if findings is None or (isinstance(findings, float) and findings != findings):
            findings = 'N/A'
        context_parts.append(f"""
PMID: {article['pmid']}
Title: {article['title']}
Key Findings: {findings}
Score: {article['score']}
---
""")
    articles_context = "".join(context_parts)

    # Keep the prompt compact; mark truncation only when it actually happened.
    max_case_chars = 500
    case_text = results['case_text']
    case_summary = case_text[:max_case_chars]
    if len(case_text) > max_case_chars:
        case_summary += "..."

    prompt = f"""Based on the following PubMed research articles about {results['disease']} 
with actionable events {', '.join(results['events'])}, provide evidence-based treatment recommendations.

Case Summary:
{case_summary}

Top Research Articles:
{articles_context}

Please provide:
1. Primary treatment recommendations with specific drug names/protocols
2. Alternative treatment options
3. Monitoring recommendations
4. Key clinical considerations

Cite specific PMIDs to support each recommendation.
"""

    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[prompt],
        config=GenerateContentConfig(
            temperature=0.1,
            max_output_tokens=2048,
        )
    )

    # response.text is None when the response carries no text part.
    if not response.text:
        return "The model returned no text; no treatment recommendations were generated."
    return response.text

### 11. Interactive Medical Consultation Chat

In [None]:
class MedicalConsultationChat:
    """Interactive chat for medical consultation based on analysis results.

    The chat is grounded in a fixed context built once from the analysis
    results (disease, actionable events, and the five leading articles) and
    carries the last five question/answer exchanges into every new prompt.
    """

    # Returned by ask() when the model response contains no text.
    NO_ANSWER_MESSAGE = "No response was generated for this question. Please try rephrasing it."

    def __init__(self, results):
        """Store the results dict from ``process_medical_case`` and build the prompt context."""
        self.results = results
        self.chat_history = []  # list of (question, answer) tuples, oldest first
        self.context = self._build_context()

    @staticmethod
    def _or_na(value):
        """Return 'N/A' for a missing cell (None or NaN), otherwise the value unchanged."""
        if value is None or (isinstance(value, float) and value != value):
            return 'N/A'
        return value

    def _build_context(self):
        """Build the context preamble from the analysis results."""
        top_articles = self.results['articles'].head(5)

        context = f"""Medical Case Context:
Disease: {self.results['disease']}
Actionable Events: {', '.join(self.results['events'])}

Top Research Articles:
"""

        for _, article in top_articles.iterrows():
            context += f"""
- PMID {article['pmid']}: {article['title']}
  Key findings: {self._or_na(article.get('key_findings'))}
"""

        return context

    def ask(self, question):
        """Ask a question about the medical case.

        Args:
            question: The user's question.

        Returns:
            str: The model's answer, or ``NO_ANSWER_MESSAGE`` when the model
            returned no text. Only real answers are added to
            ``chat_history``, so a failed turn does not leak into later
            prompts.
        """
        # Build conversation history
        history_str = ""
        for q, a in self.chat_history[-5:]:  # Keep last 5 exchanges
            history_str += f"\nUser: {q}\nAssistant: {a}\n"

        prompt = f"""{self.context}

Previous conversation:
{history_str}

Current question: {question}

Please provide a detailed, evidence-based response. Cite specific PMIDs when referencing research.
"""

        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[prompt],
            config=GenerateContentConfig(
                temperature=0.3,
                max_output_tokens=1024,
            )
        )

        answer = response.text
        if not answer:
            # response.text is None when the response carries no text part.
            return self.NO_ANSWER_MESSAGE

        self.chat_history.append((question, answer))
        return answer

    def display_chat_interface(self):
        """Display the interactive chat interface (text area, button, and transcript)."""
        import html

        output = widgets.Output()
        question_input = widgets.Textarea(
            placeholder='Ask a question about the medical case...',
            layout={'width': '100%', 'height': '80px'}
        )
        submit_button = widgets.Button(
            description='Ask',
            button_style='primary',
            icon='paper-plane'
        )

        def on_submit(b):
            question = question_input.value.strip()
            if not question:
                return

            with output:
                # User input and model output are untrusted text: escape
                # before embedding in HTML.
                safe_question = html.escape(question)
                display(HTML(f'<div style="background-color: #f0f0f0; padding: 10px; margin: 5px 0; border-radius: 5px;"><b>You:</b> {safe_question}</div>'))

                # Get response
                answer = self.ask(question)

                safe_answer = html.escape(answer).replace("\n", "<br>")
                display(HTML(f'<div style="background-color: #e3f2fd; padding: 10px; margin: 5px 0; border-radius: 5px;"><b>Medical Consultant:</b><br>{safe_answer}</div>'))

            question_input.value = ''

        submit_button.on_click(on_submit)

        safe_disease = html.escape(str(self.results["disease"]))

        # Display interface
        display(widgets.VBox([
            widgets.HTML('<h3>💬 Medical Consultation Chat</h3>'),
            widgets.HTML(f'<p>Disease: <b>{safe_disease}</b></p>'),
            output,
            widgets.HBox([question_input, submit_button])
        ]))

## Complete Example: Analyzing a Medical Case

Let's analyze a sample pediatric leukemia case through the complete pipeline.

In [None]:
# Illustrative free-text case notes used as the pipeline input in the steps below
SAMPLE_CASE = """
A 4-year-old male presents with a 3-week history of progressive fatigue, pallor, and easy bruising. 
Physical examination reveals hepatosplenomegaly and scattered petechiae. 

Laboratory findings:
- WBC: 45,000/μL with 80% blasts
- Hemoglobin: 7.2 g/dL
- Platelets: 32,000/μL

Flow cytometry: CD33+, CD13+, CD117+, CD34+, HLA-DR+, CD19-, CD3-

Cytogenetics: 46,XY,t(9;11)(p21.3;q23.3)
Molecular: KMT2A-MLLT3 fusion detected, FLT3-ITD positive, NRAS G12D mutation

Diagnosis: KMT2A-rearranged acute myeloid leukemia (AML)
"""

# Echo the notes under a header so the reader sees exactly what is analyzed
print("📋 Sample Case Notes:", SAMPLE_CASE, sep="\n")

### Step 1: Run the Analysis Pipeline

In [None]:
# Run extraction, PubMed search, article analysis and scoring on the sample case
results = process_medical_case(case_text=SAMPLE_CASE, num_articles=10)

### Step 2: Visualize Results

In [None]:
# Charts first: score bar chart plus the score breakdown of the best article
visualize_results(results)

# Then one detail card per article for the five leading articles
display_top_articles(results, n=5)

### Step 3: Generate Treatment Recommendations

In [None]:
# Ask the model for recommendations grounded in the top-scoring articles
recommendations = generate_treatment_recommendations(results)

# Render a heading followed by the model's Markdown response
display(
    HTML('<h3>🏥 Treatment Recommendations</h3>'),
    Markdown(recommendations),
)

### Step 4: Interactive Medical Consultation

In [None]:
# Build a chat session grounded in the analysis results and show its widget UI
med_chat = MedicalConsultationChat(results=results)
med_chat.display_chat_interface()

# Questions you might try:
#   "What is the prognosis for this specific KMT2A rearrangement?"
#   "Are there any ongoing clinical trials for FLT3-ITD positive pediatric AML?"
#   "What are the key monitoring parameters during treatment?"
#   "How does the NRAS mutation affect treatment selection?"

## Cleaning up

To avoid incurring charges to your Google Cloud account for the resources used in this notebook, follow these steps:

1. If you no longer need the project, delete it in the [Google Cloud console](https://console.cloud.google.com/). Learn more in the Google Cloud documentation for [managing and deleting your project](https://cloud.google.com/resource-manager/docs/creating-managing-projects).
2. If you are keeping the project, disable the [Vertex AI API](https://console.cloud.google.com/apis/api/aiplatform.googleapis.com) in the Google Cloud console.