In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
!pip install sentence-transformers neo4j torch

Collecting neo4j
  Downloading neo4j-6.0.3-py3-none-any.whl.metadata (5.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidi

In [9]:
"""Generate embeddings for all entities in Neo4j and create vector index.

Run this in Kaggle with GPU to speed up embedding generation.

Requirements:
- sentence-transformers
- neo4j
- torch (for GPU acceleration)

Usage:
    python generate_embeddings.py

Environment variables needed:
- NEO4J_URI
- NEO4J_USER  
- NEO4J_PASSWORD
"""
import os
import torch
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import logging
from kaggle_secrets import UserSecretsClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """Generate and store embeddings for Neo4j entities."""
    
    def __init__(
        self,
        neo4j_uri: str,
        neo4j_user: str,
        neo4j_password: str,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        batch_size: int = 32
    ):
        """Initialize generator.
        
        Args:
            neo4j_uri: Neo4j connection URI
            neo4j_user: Neo4j username
            neo4j_password: Neo4j password
            model_name: HuggingFace model for embeddings
            batch_size: Batch size for embedding generation
        """
        # Detect device (GPU if available)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {self.device}")
        if self.device == 'cuda':
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        
        # Initialize model and move to GPU (double ensure)
        self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model.to(self.device)  # Explicitly move model to device (belt + suspenders)
        self.batch_size = batch_size
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        
        logger.info(f"Initialized with model: {model_name}")
        logger.info(f"Embedding dimension: {self.embedding_dim}")
        logger.info(f"Batch size: {batch_size}")
        logger.info(f"Model device: {next(self.model.parameters()).device}")  # Verify actual device
    
    def get_all_entities(self):
        """Fetch all entities with ALL properties from Neo4j."""
        query = """
        MATCH (n)
        WHERE n:Country OR n:Disease OR n:Outbreak OR n:VaccinationRecord 
           OR n:Organization OR n:Vaccine OR n:PandemicEvent
        RETURN 
            elementId(n) as id,
            labels(n)[0] as type,
            properties(n) as properties
        """
        
        with self.driver.session() as session:
            result = session.run(query)
            entities = [record.data() for record in result]
        
        logger.info(f"Found {len(entities)} entities")
        return entities
    
    def create_text_representation(self, entity):
        """Create COMPREHENSIVE text representation for embedding.
        
        Includes ALL properties from the entity for maximum RAG/chatbot quality.
        Arrays are converted to comma-separated strings.
        """
        parts = []
        entity_type = entity['type']
        props = entity.get('properties', {})
        
        # Helper to format arrays
        def format_array(arr):
            if isinstance(arr, list):
                return ', '.join(str(x) for x in arr if x)
            return str(arr) if arr else ''
        
        # Add type
        parts.append(f"Entity Type: {entity_type}")
        
        # ==== DISEASE - ALL PROPERTIES ====
        if entity_type == 'Disease':
            # Basic identifiers
            if props.get('id'):
                parts.append(f"ID: {props['id']}")
            if props.get('name'):
                parts.append(f"Name: {props['name']}")
            if props.get('fullName'):
                parts.append(f"Full Name: {props['fullName']}")
            
            # Medical classification codes
            if props.get('icd10'):
                parts.append(f"ICD-10 Code: {props['icd10']}")
            if props.get('mesh'):
                parts.append(f"MeSH Code: {props['mesh']}")
            
            # Disease category and type
            if props.get('category'):
                parts.append(f"Category: {props['category']}")
            if props.get('pathogen'):
                parts.append(f"Pathogen: {props['pathogen']}")
            if props.get('causativeAgent'):
                parts.append(f"Causative Agent: {props['causativeAgent']}")
            if props.get('medicalSpecialty'):
                parts.append(f"Medical Specialty: {props['medicalSpecialty']}")
            
            # Clinical information - ARRAYS
            if props.get('symptoms'):
                parts.append(f"Symptoms: {format_array(props['symptoms'])}")
            if props.get('treatments'):
                parts.append(f"Treatments: {format_array(props['treatments'])}")
            if props.get('drugs'):
                parts.append(f"Drugs: {format_array(props['drugs'])}")
            if props.get('possibleTreatments'):
                parts.append(f"Possible Treatments: {format_array(props['possibleTreatments'])}")
            if props.get('riskFactors'):
                parts.append(f"Risk Factors: {format_array(props['riskFactors'])}")
            if props.get('transmissionMethods'):
                parts.append(f"Transmission Methods: {format_array(props['transmissionMethods'])}")
            
            # Prevention and incubation
            if props.get('prevention'):
                parts.append(f"Prevention: {props['prevention']}")
            if props.get('incubationPeriod'):
                parts.append(f"Incubation Period: {props['incubationPeriod']}")
            
            # Descriptions
            if props.get('description'):
                parts.append(f"Description: {props['description'][:1000]}")
            if props.get('wikipediaAbstract'):
                parts.append(f"Wikipedia Abstract: {props['wikipediaAbstract'][:1000]}")
            if props.get('wikipediaUrl'):
                parts.append(f"Wikipedia URL: {props['wikipediaUrl']}")
            if props.get('dbpediaUri'):
                parts.append(f"DBpedia URI: {props['dbpediaUri']}")
            if props.get('thumbnailUrl'):
                parts.append(f"Image: {props['thumbnailUrl']}")
            
            # Status
            if props.get('eradicated'):
                parts.append(f"Eradicated: {props['eradicated']}")
            if props.get('pandemic'):
                parts.append(f"Pandemic: {props['pandemic']}")
            if props.get('dataSource'):
                parts.append(f"Data Source: {props['dataSource']}")
                
        # ==== COUNTRY - ALL PROPERTIES ====
        elif entity_type == 'Country':
            if props.get('name'):
                parts.append(f"Country Name: {props['name']}")
            if props.get('code'):
                parts.append(f"Country Code: {props['code']}")
            if props.get('iso2'):
                parts.append(f"ISO-2 Code: {props['iso2']}")
            
            # Geographic info
            if props.get('continent'):
                parts.append(f"Continent: {props['continent']}")
            if props.get('capital'):
                parts.append(f"Capital: {props['capital']}")
            if props.get('latitude') and props.get('longitude'):
                parts.append(f"Coordinates: {props['latitude']}, {props['longitude']}")
            if props.get('areaKm2'):
                parts.append(f"Area: {props['areaKm2']:,} km²")
            
            # Demographics
            if props.get('population'):
                parts.append(f"Population: {props['population']:,}")
            if props.get('officialLanguage'):
                parts.append(f"Official Language: {props['officialLanguage']}")
            
            # Economic
            if props.get('gdp'):
                parts.append(f"GDP: ${props['gdp']:,}")
            if props.get('lifeExpectancy'):
                parts.append(f"Life Expectancy: {props['lifeExpectancy']} years")
            
            # External links
            if props.get('wikipediaUrl'):
                parts.append(f"Wikipedia: {props['wikipediaUrl']}")
            if props.get('dbpediaUri'):
                parts.append(f"DBpedia: {props['dbpediaUri']}")
                
        # ==== OUTBREAK - ALL PROPERTIES ====
        elif entity_type == 'Outbreak':
            if props.get('id'):
                parts.append(f"Outbreak ID: {props['id']}")
            if props.get('year'):
                parts.append(f"Year: {props['year']}")
            if props.get('date'):
                parts.append(f"Date: {props['date']}")
            
            # Case statistics (comprehensive)
            if props.get('cases'):
                parts.append(f"Cases: {int(props['cases']):,}")
            if props.get('deaths'):
                parts.append(f"Deaths: {int(props['deaths']):,}")
            if props.get('confirmedDeaths'):
                parts.append(f"Confirmed Deaths: {int(props['confirmedDeaths']):,}")
            if props.get('excessDeaths'):
                parts.append(f"Excess Deaths: {props['excessDeaths']:,}")
            if props.get('confidenceIntervalTop'):
                parts.append(f"Confidence Interval Top: {props['confidenceIntervalTop']}")
            if props.get('confidenceIntervalBottom'):
                parts.append(f"Confidence Interval Bottom: {props['confidenceIntervalBottom']}")
            
            # Vaccination statistics (for VaccinationRecord outbreaks)
            if props.get('coverage'):
                parts.append(f"Vaccination Coverage: {props['coverage']}%")
            if props.get('totalVaccinated'):
                parts.append(f"Total Vaccinated: {props['totalVaccinated']:,}")
            
            # Links to disease/country
            if props.get('diseaseId'):
                parts.append(f"Disease: {props['diseaseId']}")
            if props.get('diseaseName'):
                parts.append(f"Disease Name: {props['diseaseName']}")
            if props.get('countryCode'):
                parts.append(f"Country: {props['countryCode']}")
            if props.get('countryName'):
                parts.append(f"Country Name: {props['countryName']}")
                
        # ==== ORGANIZATION - ALL PROPERTIES ====
        elif entity_type == 'Organization':
            if props.get('name'):
                parts.append(f"Organization: {props['name']}")
            if props.get('acronym'):
                parts.append(f"Acronym: {props['acronym']}")
            if props.get('role'):
                parts.append(f"Role: {props['role']}")
            if props.get('headquarters'):
                parts.append(f"Headquarters: {props['headquarters']}")
            if props.get('founded'):
                parts.append(f"Founded: {props['founded']}")
            if props.get('website'):
                parts.append(f"Website: {props['website']}")
                
        # ==== VACCINE - ALL PROPERTIES ====
        elif entity_type == 'Vaccine':
            if props.get('name'):
                parts.append(f"Vaccine Name: {props['name']}")
            if props.get('vaccineName'):
                parts.append(f"Vaccine: {props['vaccineName']}")
            if props.get('manufacturer'):
                parts.append(f"Manufacturer: {props['manufacturer']}")
            if props.get('vaccineType'):
                parts.append(f"Vaccine Type: {props['vaccineType']}")
            if props.get('approvalDate'):
                parts.append(f"Approval Date: {props['approvalDate']}")
            if props.get('description'):
                parts.append(f"Description: {props['description'][:500]}")
                
        # ==== VACCINATION RECORD - ALL PROPERTIES ====
        elif entity_type == 'VaccinationRecord':
            if props.get('id'):
                parts.append(f"Record ID: {props['id']}")
            if props.get('vaccineName'):
                parts.append(f"Vaccine: {props['vaccineName']}")
            if props.get('year'):
                parts.append(f"Year: {props['year']}")
            if props.get('coverage'):
                parts.append(f"Coverage: {props['coverage']}%")
            if props.get('totalVaccinated'):
                parts.append(f"Total Vaccinated: {props['totalVaccinated']:,}")
            if props.get('countryCode'):
                parts.append(f"Country: {props['countryCode']}")
                
        # ==== PANDEMIC EVENT - ALL PROPERTIES ====
        elif entity_type == 'PandemicEvent':
            if props.get('name'):
                parts.append(f"Event: {props['name']}")
            if props.get('abstract'):
                parts.append(f"Description: {props['abstract'][:1000]}")
            if props.get('startDate'):
                parts.append(f"Start Date: {props['startDate']}")
            if props.get('deathToll'):
                parts.append(f"Death Toll: {props['deathToll']}")
            if props.get('location'):
                parts.append(f"Location: {props['location']}")
        
        # Join all parts with separator
        text = " | ".join(parts) if parts else "Unknown entity"
        return text
    
    def generate_embeddings(self, entities):
        """Generate embeddings for all entities using GPU if available."""
        logger.info("Generating embeddings...")
        logger.info(f"Processing {len(entities)} entities in batches of {self.batch_size}")
        
        # Prepare texts
        texts = [self.create_text_representation(e) for e in entities]
        
        # Generate embeddings in batches with GPU acceleration
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=True,
            convert_to_numpy=True,
            device=self.device,  # Explicitly use GPU
            normalize_embeddings=True  # Normalize for cosine similarity
        )
        
        logger.info(f"✓ Generated {len(embeddings)} embeddings")
        return embeddings
    
    def store_embeddings(self, entities, embeddings):
        """Store embeddings back to Neo4j.
        
        OVERWRITES existing embeddings to ensure they're up-to-date with the latest schema.
        """
        logger.info("Storing embeddings in Neo4j (will overwrite existing)...")
        
        query = """
        MATCH (n)
        WHERE elementId(n) = $id
        SET n.embedding = $embedding
        """
        
        with self.driver.session() as session:
            for entity, embedding in tqdm(zip(entities, embeddings), total=len(entities), desc="Storing"):
                session.run(query, {
                    "id": entity['id'],
                    "embedding": embedding.tolist()
                })
        
        logger.info(f"✓ Stored {len(embeddings)} embeddings (overwrote any existing)")
    
    def create_vector_index(self):
        """Create vector index for similarity search."""
        logger.info("Creating vector index...")
        
        with self.driver.session() as session:
            # Check if index exists
            result = session.run("SHOW INDEXES")
            existing = [r['name'] for r in result]
            
            if 'entityEmbedding' in existing:
                logger.info("Vector index 'entityEmbedding' already exists, dropping...")
                session.run("DROP INDEX entityEmbedding IF EXISTS")
            
            # Create vector index for all entity types
            query = f"""
            CREATE VECTOR INDEX entityEmbedding IF NOT EXISTS
            FOR (n:Country)
            ON n.embedding
            OPTIONS {{
                indexConfig: {{
                    `vector.dimensions`: {self.embedding_dim},
                    `vector.similarity_function`: 'cosine'
                }}
            }}
            """
            
            try:
                session.run(query)
                logger.info("✓ Vector index 'entityEmbedding' created")
            except Exception as e:
                logger.warning(f"Note: {e}")
                logger.info("Trying alternative index creation method...")
                
                # Alternative: Create for specific label
                for label in ['Country', 'Disease', 'Outbreak', 'VaccinationRecord', 
                             'Organization', 'Vaccine', 'PandemicEvent']:
                    try:
                        query = f"""
                        CREATE VECTOR INDEX entityEmbedding_{label} IF NOT EXISTS
                        FOR (n:{label})
                        ON n.embedding
                        OPTIONS {{
                            indexConfig: {{
                                `vector.dimensions`: {self.embedding_dim},
                                `vector.similarity_function`: 'cosine'
                            }}
                        }}
                        """
                        session.run(query)
                        logger.info(f"✓ Created index for {label}")
                    except Exception as e2:
                        logger.error(f"Failed to create index for {label}: {e2}")
    
    def verify_setup(self):
        """Verify embeddings and index are working."""
        logger.info("\nVerifying setup...")
        
        with self.driver.session() as session:
            # Count nodes with embeddings
            result = session.run("""
                MATCH (n)
                WHERE n.embedding IS NOT NULL
                RETURN count(n) as count
            """)
            count = result.single()['count']
            logger.info(f"✓ {count} nodes have embeddings")
            
            # List indexes
            result = session.run("SHOW INDEXES")
            indexes = [r['name'] for r in result]
            logger.info(f"✓ Found indexes: {', '.join(indexes)}")
    
    def close(self):
        """Close Neo4j connection."""
        self.driver.close()
    
    def run(self):
        """Run the full embedding generation pipeline."""
        try:
            logger.info("=" * 60)
            logger.info("EpiHelix - Embedding Generation")
            logger.info("=" * 60)
            
            # Step 1: Fetch entities
            entities = self.get_all_entities()
            
            if not entities:
                logger.error("No entities found in Neo4j!")
                return
            
            # Step 2: Generate embeddings
            embeddings = self.generate_embeddings(entities)
            
            # Step 3: Store embeddings
            self.store_embeddings(entities, embeddings)
            
            # Step 4: Create vector index
            self.create_vector_index()
            
            # Step 5: Verify
            self.verify_setup()
            
            logger.info("\n" + "=" * 60)
            logger.info("✓ Embedding generation complete!")
            logger.info("=" * 60)
            logger.info("\nYour backend is now ready for semantic search.")
            
        except Exception as e:
            logger.error(f"Error: {e}", exc_info=True)
        finally:
            self.close()


def main():
    """Main entry point."""
    # Get credentials from environment
    user_secrets = UserSecretsClient()
    neo4j_user = "neo4j"
    neo4j_uri = user_secrets.get_secret("NEO4J_URI")
    neo4j_password = user_secrets.get_secret("NEO4J_PASSWORD")
    
    if not neo4j_password:
        logger.error("NEO4J_PASSWORD environment variable not set!")
        logger.info("\nSet it with:")
        logger.info("  export NEO4J_PASSWORD='your-password'")
        return
    
    # Initialize and run
    generator = EmbeddingGenerator(
        neo4j_uri=neo4j_uri,
        neo4j_user=neo4j_user,
        neo4j_password=neo4j_password,
        batch_size=256  # Larger batch size for GPU (Kaggle has 16GB GPU)
    )
    
    generator.run()


if __name__ == "__main__":
    main()

GPU: Tesla P100-PCIE-16GB


Batches:   0%|          | 0/382 [00:00<?, ?it/s]

Storing:   2%|▏         | 2279/97681 [06:23<4:27:13,  5.95it/s]
ERROR:__main__:Error: {neo4j_code: Neo.TransientError.General.MemoryPoolOutOfMemoryError} {message: The allocation of an extra 2.0 MiB would use more than the limit 250.0 MiB. Currently using 249.1 MiB. dbms.memory.transaction.total.max threshold reached} {gql_status: 51N72} {gql_status_description: error: system configuration or operation exception - memory pool out of memory. Failed to allocate memory in a memory pool. See dbms.memory.transaction.total.max in the neo4j.conf file.}
Traceback (most recent call last):
  File "/tmp/ipykernel_48/249009054.py", line 449, in run
    self.store_embeddings(entities, embeddings)
  File "/tmp/ipykernel_48/249009054.py", line 348, in store_embeddings
    session.run(query, {
  File "/usr/local/lib/python3.11/dist-packages/neo4j/_sync/work/session.py", line 320, in run
    self._auto_result._run(
  File "/usr/local/lib/python3.11/dist-packages/neo4j/_sync/work/result.py", line 237, i