In [1]:
!pip install -q transformers accelerate bitsandbytes sentence-transformers
!pip install -q chromadb pandas pyarrow tqdm streamlit
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

print("‚úì All packages installed")

‚úì All packages installed



CELL 2: IMPORTS AND MOUNT GOOGLE DRIVE

In [2]:
import os
import sys
import json
import pickle
import time
import re
import hashlib
import uuid
import gc
import warnings
from pathlib import Path
from datetime import datetime
from collections import defaultdict, Counter, OrderedDict
from typing import List, Dict, Optional, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

# Transformers and quantization
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)

# Embeddings and vector DB
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings

# Streamlit
import streamlit as st
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Silence every warning for the whole session (this also hides deprecation notices)
warnings.filterwarnings('ignore')

# Mount Google Drive; if it is already mounted, Colab only prints a notice
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
print("‚úì Google Drive mounted")

# Choose the compute device; report GPU model and total memory when CUDA is present
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"‚úì Device: {device}")
if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  GPU: {gpu_name}\n  Memory: {gpu_memory:.2f}GB")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úì Google Drive mounted
‚úì Device: cuda
  GPU: Tesla T4
  Memory: 15.83GB


CELL 3: CONFIGURATION CLASS

In [3]:
class Config:
    """Single place for every path, model name and hyper-parameter of the RAG pipeline.

    All persistent artifacts live under ``DRIVE_ROOT`` on the mounted Google Drive,
    so they survive Colab runtime resets.
    """

    # ---- Google Drive locations -------------------------------------------
    DRIVE_ROOT = Path("/content/drive/MyDrive/Clinical_RAG_System")
    DATASET_ROOT = Path("/content/drive/MyDrive/mimic-iv-ext-direct-1.0.0/mimic-iv-ext-direct-1.0.0/Finished")

    PROCESSED_DATA_DIR = DRIVE_ROOT / "processed_data"
    CHROMA_DB_PATH = DRIVE_ROOT / "chroma_db"
    MODELS_CACHE = DRIVE_ROOT / "models_cache"
    CHECKPOINTS_DIR = DRIVE_ROOT / "checkpoints"

    # ---- Cached artifacts (reused on later runs when present) --------------
    PROCESSED_DF_PATH = PROCESSED_DATA_DIR / "processed_documents.parquet"
    EMBEDDINGS_PATH = PROCESSED_DATA_DIR / "embeddings.npy"
    PROCESSING_STATS_PATH = PROCESSED_DATA_DIR / "processing_stats.json"

    # ---- Hugging Face model ids --------------------------------------------
    EMBEDDING_MODEL = "intfloat/e5-small-v2"
    GENERATION_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"

    # ---- Chunking / preprocessing ------------------------------------------
    CHUNK_SIZE_TOKENS = 400       # passed to chunk_text_by_tokens as chunk_size
    CHUNK_OVERLAP_TOKENS = 100    # passed to chunk_text_by_tokens as overlap
    MIN_TEXT_LENGTH = 50          # minimum characters for a note / chunk to be kept
    BATCH_SIZE = 1000

    # ---- Text generation ---------------------------------------------------
    MAX_NEW_TOKENS = 512
    TEMPERATURE = 0.7
    TOP_P = 0.9
    TOP_K = 50
    REPETITION_PENALTY = 1.1
    MAX_CONTEXT_TOKENS = 2000
    MAX_INPUT_LENGTH = 4096

    # ---- Retrieval (E5 models expect these input prefixes) -----------------
    DEFAULT_TOP_K = 5
    QUERY_PREFIX = "query: "
    PASSAGE_PREFIX = "passage: "

    @classmethod
    def setup_directories(cls):
        """Create the Drive directory tree; safe to call repeatedly."""
        required_dirs = (
            cls.DRIVE_ROOT,
            cls.PROCESSED_DATA_DIR,
            cls.CHROMA_DB_PATH,
            cls.MODELS_CACHE,
            cls.CHECKPOINTS_DIR,
        )
        for directory in required_dirs:
            directory.mkdir(parents=True, exist_ok=True)
        print("‚úì Directories created/verified")

    @classmethod
    def check_cached_data(cls):
        """Return ``{artifact_name: file_exists}`` for each cached artifact."""
        probes = {
            'processed_df': cls.PROCESSED_DF_PATH,
            'embeddings': cls.EMBEDDINGS_PATH,
            'chromadb': cls.CHROMA_DB_PATH / "chroma.sqlite3",
            'stats': cls.PROCESSING_STATS_PATH,
        }
        return {name: target.exists() for name, target in probes.items()}

# Instantiate the configuration and make sure the Drive folder tree exists
config = Config()
config.setup_directories()

# Report which artifacts from earlier sessions are already cached on Drive
cache_status = config.check_cached_data()
print("\nüì¶ Cache Status:")
for item, exists in cache_status.items():
    if exists:
        status = "‚úì Exists"
    else:
        status = "‚úó Missing"
    print("  " + item + ": " + status)


‚úì Directories created/verified

üì¶ Cache Status:
  processed_df: ‚úì Exists
  embeddings: ‚úì Exists
  chromadb: ‚úì Exists
  stats: ‚úì Exists


CELL 4: UTILITY FUNCTIONS

In [4]:
def cleanup_memory():
    """Free Python garbage and PyTorch's cached CUDA memory.

    Runs the garbage collector and empties the CUDA caching allocator. When a GPU
    is present, it then waits for all queued CUDA work to finish.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

def get_gpu_memory():
    """Return GPU 0 memory figures in gigabytes (1e9 bytes), or None without CUDA.

    Keys: ``allocated`` (held by tensors), ``reserved`` (held by the caching
    allocator), ``total`` (device capacity), and ``free`` (``total - reserved``).
    """
    if not torch.cuda.is_available():
        return None

    bytes_per_gb = 1e9
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    reserved_bytes = torch.cuda.memory_reserved(0)
    return {
        'allocated': torch.cuda.memory_allocated(0) / bytes_per_gb,
        'reserved': reserved_bytes / bytes_per_gb,
        'total': total_bytes / bytes_per_gb,
        'free': (total_bytes - reserved_bytes) / bytes_per_gb,
    }

def clean_text(text: str) -> str:
    """Normalize clinical text for chunking while keeping its medical content intact.

    Steps, in order:
      1. Non-string input becomes an empty string.
      2. Characters that cannot be encoded as UTF-8 (lone surrogates) are dropped.
      3. Non-whitespace control characters (code points below 32, e.g. NUL or BEL)
         are dropped. Whitespace controls (newline, tab, carriage return, ...) are
         kept so the next steps treat them as word separators.
      4. De-identification markers of the form ``[** ... **]`` become ``[REDACTED]``.
      5. Each whitespace run collapses to one space, and both ends are stripped.

    Character removal runs *before* whitespace normalization. That way a removed
    character can never leave a double space or a trailing space behind.

    Returns:
        The cleaned single-line text (possibly ``""``).
    """
    if not isinstance(text, str):
        return ""

    text = text.encode('utf-8', errors='ignore').decode('utf-8')
    text = ''.join(char for char in text if ord(char) >= 32 or char.isspace())
    text = re.sub(r'\[\*\*[^\]]+\*\*\]', '[REDACTED]', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def simple_tokenize(text: str) -> List[str]:
    """Split ``text`` on runs of whitespace, discarding empty tokens.

    This is a cheap word-count proxy used for chunk sizing, not a model tokenizer.
    """
    tokens = text.split(None)
    return tokens

def chunk_text_by_tokens(text: str, chunk_size: int = 400, overlap: int = 100) -> List[str]:
    """Split text into sentence-aligned chunks of at most ``chunk_size`` whitespace tokens.

    Sentences (split after ``.``, ``!`` or ``?``) are packed greedily into the current
    chunk until the next one would overflow it. Each new chunk then starts with the
    last ``overlap`` tokens of the previous chunk. That carry is trimmed when needed,
    so carry plus the incoming sentence still fits.

    A single sentence longer than ``chunk_size`` (common in unpunctuated clinical
    notes) is first cut into ``chunk_size``-token pieces. As a result, no chunk
    exceeds ``chunk_size`` tokens.

    Args:
        text: Input text, expected to be already whitespace-normalized.
        chunk_size: Maximum tokens per chunk; must be >= 1.
        overlap: Tokens carried from the end of one chunk into the next. No carry
            happens when the previous chunk has ``overlap`` tokens or fewer.

    Returns:
        A non-empty list of chunks (``[text]`` if nothing else was produced).

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # (unit_text, token_count) pairs. Over-long sentences are pre-split so every unit fits.
    units = []
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        words = simple_tokenize(sentence)
        if len(words) <= chunk_size:
            units.append((sentence, len(words)))
        else:
            for start in range(0, len(words), chunk_size):
                piece = words[start:start + chunk_size]
                units.append((' '.join(piece), len(piece)))

    chunks = []
    current_chunk = []
    current_tokens = 0

    for unit, unit_tokens in units:
        if current_chunk and current_tokens + unit_tokens > chunk_size:
            chunks.append(' '.join(current_chunk))

            previous_tokens = simple_tokenize(' '.join(current_chunk))
            # Never carry more than the room left beside the incoming unit.
            keep = min(overlap, chunk_size - unit_tokens)
            if len(previous_tokens) > overlap and keep > 0:
                current_chunk = [' '.join(previous_tokens[-keep:])]
                current_tokens = keep
            else:
                current_chunk = []
                current_tokens = 0

        current_chunk.append(unit)
        current_tokens += unit_tokens

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks if chunks else [text]

def generate_doc_id(file_path: Path) -> str:
    """Derive a stable, UUID-formatted document id from ``file_path``.

    The MD5 digest of the path's string form is rendered as a UUID, so the same
    path always maps to the same id. MD5 serves as an identity hash, not for security.
    The id depends on the full path string, including any mount prefix.
    """
    digest = hashlib.md5(str(file_path).encode('utf-8')).hexdigest()
    return str(uuid.UUID(hex=digest))

def extract_text_from_json(data: dict, min_field_length: int = 20) -> str:
    """Extract the clinical narrative from one parsed JSON note.

    Strategy:
      1. Return the first *non-blank* string among the well-known text keys
         (``text``, ``note``, ``content``, ...), in priority order. A key that is
         present but empty or whitespace-only falls through to the next candidate
         instead of producing an empty result.
      2. Otherwise, join with single spaces every top-level string value longer
         than ``min_field_length`` characters, in the dict's key order.

    Args:
        data: Parsed JSON object. Only top-level keys are inspected; nested dicts
            and lists are ignored.
        min_field_length: Exclusive minimum length for a value to be used by the
            fallback join (default 20, the previously hard-coded value).

    Returns:
        The extracted text, or ``""`` if nothing qualifies.
    """
    text_fields = ('text', 'note', 'content', 'clinical_note', 'report',
                   'description', 'narrative', 'summary', 'findings')

    for field in text_fields:
        value = data.get(field)
        if isinstance(value, str) and value.strip():
            return value

    text_parts = [value for value in data.values()
                  if isinstance(value, str) and len(value) > min_field_length]
    return ' '.join(text_parts)

def save_checkpoint(data: dict, checkpoint_name: str):
    """Pickle ``data`` to ``config.CHECKPOINTS_DIR / f"{checkpoint_name}.pkl"``.

    The pickle is first written to a ``.pkl.tmp`` sibling and then moved onto the
    final name with ``os.replace``. An interrupted write (e.g. a Colab disconnect)
    therefore cannot leave a truncated checkpoint under the real name for
    ``load_checkpoint`` to choke on. If writing fails, the temp file is removed and
    the exception is re-raised.
    """
    checkpoint_path = config.CHECKPOINTS_DIR / f"{checkpoint_name}.pkl"
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    print(f"‚úì Checkpoint saved: {checkpoint_name}")

def load_checkpoint(checkpoint_name: str) -> Optional[dict]:
    """Return the object stored at ``config.CHECKPOINTS_DIR / f"{checkpoint_name}.pkl"``.

    Returns None when no such checkpoint exists.

    Security note: ``pickle.load`` can execute arbitrary code. Only load
    checkpoints that this notebook wrote itself.
    """
    checkpoint_path = config.CHECKPOINTS_DIR / f"{checkpoint_name}.pkl"
    if not checkpoint_path.exists():
        return None
    with checkpoint_path.open('rb') as fh:
        return pickle.load(fh)

print("‚úì Utility functions loaded")

‚úì Utility functions loaded


CELL 5: DATA PROCESSOR WITH PROGRESS TRACKING

In [6]:
class DataProcessor:
    """Turn the MIMIC-IV-Ext JSON tree into a deduplicated DataFrame of text chunks.

    Files are discovered under ``config.DATASET_ROOT``. The first directory level
    below the root becomes ``disease_category`` and the second becomes
    ``disease_subtype``.
    """

    def __init__(self, config: Config):
        self.config = config
        # Running counters; written next to the parquet by save_processed_data()
        self.stats = {
            'files_processed': 0,
            'files_skipped': 0,
            'total_chunks': 0,
            'errors': []
        }

    def collect_json_files(self) -> List[Dict]:
        """Collect all non-hidden ``*.json`` files below ``DATASET_ROOT``.

        Traversal is sorted, so the file order (and therefore the row order of the
        processed data) is reproducible. Category and subtype are taken from the
        *directories* above a file, never from its filename. Missing levels become
        ``"Unknown"`` / ``"root"``.

        Raises:
            FileNotFoundError: If ``DATASET_ROOT`` is not an existing directory.
                Otherwise a wrong path would silently yield zero files.
        """
        print("üîç Scanning dataset directory...")
        dataset_root = Path(self.config.DATASET_ROOT)
        if not dataset_root.is_dir():
            raise FileNotFoundError(f"Dataset directory not found: {dataset_root}")

        all_files = []
        for root, dirs, files in os.walk(dataset_root):
            dirs.sort()  # in-place sort makes os.walk descend in a stable order
            root_path = Path(root)
            dir_parts = root_path.relative_to(dataset_root).parts
            for file in sorted(files):
                if not file.endswith('.json') or file.startswith('.'):
                    continue
                all_files.append({
                    'path': root_path / file,
                    'disease_category': dir_parts[0] if len(dir_parts) > 0 else "Unknown",
                    'disease_subtype': dir_parts[1] if len(dir_parts) > 1 else "root",
                    'filename': file
                })

        print(f"‚úì Found {len(all_files)} JSON files")
        return all_files

    def process_files(self, json_files: List[Dict]) -> pd.DataFrame:
        """Read, clean and chunk every file, returning one row per kept chunk.

        A file is skipped (and counted in ``files_skipped``) in these cases:
          - it is not a JSON object;
          - its extracted or cleaned text is shorter than ``MIN_TEXT_LENGTH``;
          - it raises while being processed. The error is then recorded in
            ``stats['errors']``.

        Chunks shorter than ``MIN_TEXT_LENGTH`` are dropped. ``stats['total_chunks']``
        counts only the chunks actually emitted. Rows with identical text are
        deduplicated, keeping the first.

        Returns:
            A DataFrame with the columns listed in ``columns`` below. It is empty,
            but still has those columns, when nothing was kept.
        """
        columns = ['doc_id', 'chunk_id', 'text', 'disease_category', 'disease_subtype',
                   'source_file', 'chunk_index', 'total_chunks', 'metadata']
        # Keys holding narrative text are left out of the per-chunk metadata JSON
        text_keys = ('text', 'note', 'content', 'clinical_note',
                     'report', 'description', 'narrative', 'summary')
        processed_documents = []

        print(f"\nüìù Processing {len(json_files)} files...")

        pbar = tqdm(json_files, desc="Processing files", unit="file")

        for file_info in pbar:
            try:
                with open(file_info['path'], 'r', encoding='utf-8') as f:
                    data = json.load(f)

                if not isinstance(data, dict):
                    self.stats['files_skipped'] += 1
                    continue

                raw_text = extract_text_from_json(data)
                if not raw_text or len(raw_text) < self.config.MIN_TEXT_LENGTH:
                    self.stats['files_skipped'] += 1
                    continue

                cleaned_text = clean_text(raw_text)
                if not cleaned_text or len(cleaned_text) < self.config.MIN_TEXT_LENGTH:
                    self.stats['files_skipped'] += 1
                    continue

                doc_id = generate_doc_id(file_info['path'])

                chunks = chunk_text_by_tokens(
                    cleaned_text,
                    chunk_size=self.config.CHUNK_SIZE_TOKENS,
                    overlap=self.config.CHUNK_OVERLAP_TOKENS
                )

                # Metadata is identical for every chunk of a file: build and serialize it
                # once. This also means a serialization error can't leave a partly-added file.
                metadata_json = json.dumps(
                    {k: v for k, v in data.items() if k not in text_keys},
                    ensure_ascii=False
                )

                kept_chunks = 0
                for chunk_idx, chunk_text in enumerate(chunks):
                    if len(chunk_text) < self.config.MIN_TEXT_LENGTH:
                        continue

                    processed_documents.append({
                        'doc_id': doc_id,
                        'chunk_id': f"{doc_id}_chunk_{chunk_idx}",
                        'text': chunk_text,
                        'disease_category': file_info['disease_category'],
                        'disease_subtype': file_info['disease_subtype'],
                        'source_file': str(file_info['path']),
                        'chunk_index': chunk_idx,
                        'total_chunks': len(chunks),
                        'metadata': metadata_json
                    })
                    kept_chunks += 1

                self.stats['files_processed'] += 1
                self.stats['total_chunks'] += kept_chunks

                pbar.set_postfix({
                    'processed': self.stats['files_processed'],
                    'chunks': self.stats['total_chunks']
                })

            except Exception as e:
                self.stats['files_skipped'] += 1
                self.stats['errors'].append({'file': str(file_info['path']), 'error': str(e)})

        pbar.close()

        print(f"\n‚úì Processing complete:")
        print(f"  Files processed: {self.stats['files_processed']}")
        print(f"  Files skipped: {self.stats['files_skipped']}")
        print(f"  Total chunks created: {self.stats['total_chunks']}")
        print(f"  Errors: {len(self.stats['errors'])}")

        # Explicit columns keep drop_duplicates(subset=['text']) valid when nothing was kept
        df = pd.DataFrame(processed_documents, columns=columns)
        df = df.drop_duplicates(subset=['text'], keep='first')
        df = df.reset_index(drop=True)

        print(f"  Unique chunks after deduplication: {len(df)}")

        return df

    def save_processed_data(self, df: pd.DataFrame) -> Path:
        """Write ``df`` to ``PROCESSED_DF_PATH`` (parquet) and the stats to JSON."""
        print("\nüíæ Saving processed data to Google Drive...")
        output_path = self.config.PROCESSED_DF_PATH
        df.to_parquet(output_path, index=False, compression='snappy')

        with open(self.config.PROCESSING_STATS_PATH, 'w') as f:
            json.dump(self.stats, f, indent=2)

        print(f"‚úì Data saved: {output_path}")
        print(f"  Size: {output_path.stat().st_size / 1e6:.2f} MB")
        return output_path

    def load_processed_data(self) -> Optional[pd.DataFrame]:
        """Load the cached parquet (and stats, if present); return None if not cached."""
        if self.config.PROCESSED_DF_PATH.exists():
            print("üìÇ Loading cached processed data from Google Drive...")
            df = pd.read_parquet(self.config.PROCESSED_DF_PATH)

            if self.config.PROCESSING_STATS_PATH.exists():
                with open(self.config.PROCESSING_STATS_PATH, 'r') as f:
                    self.stats = json.load(f)

            print(f"‚úì Loaded {len(df)} chunks from cache")
            print(f"  Files processed: {self.stats.get('files_processed', 'N/A')}")
            print(f"  Total chunks: {self.stats.get('total_chunks', 'N/A')}")
            return df
        return None

print("‚úì DataProcessor class loaded")

‚úì DataProcessor class loaded


CELL 6: RUN DATA PROCESSING (Skip if cached)

In [7]:
print("="*80)
print("STEP 1: DATA PROCESSING")
print("="*80)

processor = DataProcessor(config)

# Try to load from cache first
df = processor.load_processed_data()

# An empty cached parquet (e.g. left by a run with a wrong DATASET_ROOT) is useless:
# treat it as missing so the dataset is reprocessed, with fresh counters.
if df is not None and df.empty:
    print("\n‚ö† Cached processed data is empty. Reprocessing dataset...")
    processor = DataProcessor(config)
    df = None

if df is None:
    print("\n‚ö† No cached data found. Processing dataset...")
    print("‚è± This will take 10-30 minutes depending on dataset size\n")

    # Collect files
    json_files = processor.collect_json_files()

    # Process files
    df = processor.process_files(json_files)

    # Never cache an empty result: it would be reloaded on every later run
    if df.empty:
        raise RuntimeError(
            f"No chunks were produced from {config.DATASET_ROOT} "
            f"({len(processor.stats['errors'])} file errors); nothing was cached."
        )

    # Save to Google Drive
    processor.save_processed_data(df)

    print("\n‚úì Processing complete and saved to Google Drive!")
else:
    print("\n‚úì Using cached processed data from Google Drive!")

# Display sample
print("\nüìä Sample of processed data:")
display(df.head())

print(f"\nüìà Dataset Statistics:")
print(f"  Total chunks: {len(df)}")
print(f"  Unique documents: {df['doc_id'].nunique()}")
print(f"  Disease categories: {df['disease_category'].nunique()}")
print(f"  Average chunk length: {df['text'].str.len().mean():.0f} characters")

# Notebook-level name used by the embedding and ChromaDB cells
processed_df = df

STEP 1: DATA PROCESSING
üìÇ Loading cached processed data from Google Drive...
‚úì Loaded 934 chunks from cache
  Files processed: 511
  Total chunks: 937

‚úì Using cached processed data from Google Drive!

üìä Sample of processed data:


Unnamed: 0,doc_id,chunk_id,text,disease_category,disease_subtype,source_file,chunk_index,total_chunks,metadata
0,6ae055d0-ca1e-a4b3-e7f0-d0c0d15f312e,6ae055d0-ca1e-a4b3-e7f0-d0c0d15f312e_chunk_0,She with multiple admissions for gastroparesis...,Diabetes,Type I Diabetes,/content/drive/MyDrive/mimic-iv-ext-direct-1.0...,0,2,"{""Type I diabetes$Intermedia_4"": {""ICA antibod..."
1,6ae055d0-ca1e-a4b3-e7f0-d0c0d15f312e,6ae055d0-ca1e-a4b3-e7f0-d0c0d15f312e_chunk_1,"on the monitor at times, but is coming down wi...",Diabetes,Type I Diabetes,/content/drive/MyDrive/mimic-iv-ext-direct-1.0...,1,2,"{""Type I diabetes$Intermedia_4"": {""ICA antibod..."
2,4e290b0a-219e-09a3-c65f-f3c995c9f987,4e290b0a-219e-09a3-c65f-f3c995c9f987_chunk_0,"nausea, vomiting, malaise Last night the patie...",Diabetes,Type I Diabetes,/content/drive/MyDrive/mimic-iv-ext-direct-1.0...,0,2,"{""Type I Diabetes$Intermedia_4"": {""ICA antibod..."
3,4e290b0a-219e-09a3-c65f-f3c995c9f987,4e290b0a-219e-09a3-c65f-f3c995c9f987_chunk_1,is feeling better and reports the Ativan and T...,Diabetes,Type I Diabetes,/content/drive/MyDrive/mimic-iv-ext-direct-1.0...,1,2,"{""Type I Diabetes$Intermedia_4"": {""ICA antibod..."
4,2b0f1a6c-5f6a-fabc-7f0f-014323655c78,2b0f1a6c-5f6a-fabc-7f0f-014323655c78_chunk_0,"Polyuria, polydypsia, weight loss Male who has...",Diabetes,Type I Diabetes,/content/drive/MyDrive/mimic-iv-ext-direct-1.0...,0,1,"{""Type I diabetes$Intermedia_4"": {""GADA antibo..."



üìà Dataset Statistics:
  Total chunks: 934
  Unique documents: 510
  Disease categories: 25
  Average chunk length: 2024 characters


CELL 7: EMBEDDING GENERATOR WITH CACHING

In [8]:
class EmbeddingGenerator:
    """Generate E5 embeddings, cache them on Google Drive, and load them into ChromaDB.

    Row ``i`` of every embedding matrix produced here corresponds to input text
    (and DataFrame row) ``i``. Every method preserves that alignment or fails loudly.
    """

    def __init__(self, config: Config, device: str = None):
        self.config = config
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None  # loaded lazily by load_model()

    def load_model(self):
        """Load the SentenceTransformer named by ``config.EMBEDDING_MODEL``.

        Weights are cached under ``config.MODELS_CACHE`` so later sessions can reuse them.

        Returns:
            The loaded model, which is also stored on ``self.model``.
        """
        print(f"üì• Loading embedding model: {self.config.EMBEDDING_MODEL}")
        print(f"   Device: {self.device}")

        cache_dir = str(self.config.MODELS_CACHE)

        self.model = SentenceTransformer(
            self.config.EMBEDDING_MODEL,
            device=self.device,
            cache_folder=cache_dir
        )
        self.model.eval()

        print("‚úì Embedding model loaded")

        if self.device == "cuda":
            mem = get_gpu_memory()
            print(f"  GPU Memory: {mem['allocated']:.2f}GB used / {mem['total']:.2f}GB total")

        return self.model

    def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode ``texts`` into L2-normalized embeddings, one row per input text.

        On a CUDA out-of-memory error, the batch size is halved and the *same* slice
        is retried. Results from earlier batches are kept, so output rows stay
        aligned with ``texts``. The model is loaded on demand if needed.

        Args:
            texts: Documents to embed (callers add the E5 passage prefix).
            batch_size: Initial number of texts per ``encode`` call.

        Returns:
            An array with ``len(texts)`` rows.

        Raises:
            RuntimeError: On a non-OOM encoding failure, or an OOM that persists at
                batch size 1. Skipping a batch would silently misalign every later row.
        """
        if self.model is None:
            self.load_model()

        print(f"\nüî¢ Generating embeddings for {len(texts)} documents...")
        print(f"   Batch size: {batch_size}")

        if not texts:
            dim = self.model.get_sentence_embedding_dimension() or 0
            return np.empty((0, dim), dtype=np.float32)

        batch_size = max(1, int(batch_size))
        embeddings = []
        start = 0
        batches_done = 0

        pbar = tqdm(total=len(texts), desc="Embedding documents", unit="doc")
        try:
            while start < len(texts):
                batch = texts[start:start + batch_size]
                try:
                    with torch.no_grad():
                        batch_embeddings = self.model.encode(
                            batch,
                            normalize_embeddings=True,
                            show_progress_bar=False,
                            convert_to_numpy=True,
                            batch_size=batch_size
                        )
                except RuntimeError as e:
                    # torch.cuda.OutOfMemoryError is a RuntimeError subclass
                    if "out of memory" not in str(e).lower():
                        print(f"\n‚ùå Error in batch {start}: {str(e)}")
                        raise
                    if batch_size == 1:
                        raise
                    print(f"\n‚ö† OOM at batch {start}. Reducing batch size...")
                    cleanup_memory()
                    batch_size = max(1, batch_size // 2)
                    print(f"   New batch size: {batch_size}")
                    continue  # retry the same slice with the smaller batch

                embeddings.append(batch_embeddings)
                start += len(batch)
                batches_done += 1
                pbar.update(len(batch))

                # Periodic memory cleanup
                if self.device == "cuda" and batches_done % 10 == 0:
                    torch.cuda.empty_cache()
        finally:
            pbar.close()

        final_embeddings = np.vstack(embeddings)
        print(f"‚úì Generated {len(final_embeddings)} embeddings")
        print(f"  Shape: {final_embeddings.shape}")
        print(f"  Size: {final_embeddings.nbytes / 1e6:.2f} MB")

        return final_embeddings

    def save_embeddings(self, embeddings: np.ndarray) -> Path:
        """Save ``embeddings`` to ``config.EMBEDDINGS_PATH`` (``.npy``)."""
        print("\nüíæ Saving embeddings to Google Drive...")
        output_path = self.config.EMBEDDINGS_PATH
        np.save(output_path, embeddings)
        print(f"‚úì Embeddings saved: {output_path}")
        print(f"  Size: {output_path.stat().st_size / 1e6:.2f} MB")
        return output_path

    def load_embeddings(self) -> Optional[np.ndarray]:
        """Load cached embeddings from Google Drive; return None if not cached."""
        if self.config.EMBEDDINGS_PATH.exists():
            print("üìÇ Loading cached embeddings from Google Drive...")
            embeddings = np.load(self.config.EMBEDDINGS_PATH)
            print(f"‚úì Loaded embeddings: {embeddings.shape}")
            return embeddings
        return None

    def create_chromadb_collection(self, df: pd.DataFrame, embeddings: np.ndarray):
        """(Re)build the ``clinical_notes`` collection from ``df`` and ``embeddings``.

        Any existing collection with that name is dropped first. Row ``i`` of
        ``embeddings`` is stored with row ``i`` of ``df``, in cosine space.

        Returns:
            The populated collection.

        Raises:
            ValueError: If ``df`` and ``embeddings`` have different row counts.
        """
        if len(df) != len(embeddings):
            raise ValueError(
                f"Row mismatch: df has {len(df)} rows but embeddings has {len(embeddings)}; "
                "regenerate embeddings for the current processed data."
            )

        print("\nüóÑÔ∏è Setting up ChromaDB...")

        # Use the same Settings as every other client in this notebook. Chroma refuses a
        # second PersistentClient for the same path with different settings in one process.
        chroma_client = chromadb.PersistentClient(
            path=str(self.config.CHROMA_DB_PATH),
            settings=Settings(anonymized_telemetry=False)
        )

        try:
            chroma_client.delete_collection(name="clinical_notes")
            print("  Deleted existing collection")
        except Exception:
            # Expected when no collection exists yet; the exception type varies by chromadb version
            pass

        collection = chroma_client.create_collection(
            name="clinical_notes",
            metadata={
                "hnsw:space": "cosine",
                "description": "MIMIC-IV clinical notes with E5 embeddings",
                "created_at": datetime.now().isoformat()
            }
        )
        print("‚úì Collection created")

        print("\nüìã Preparing data for ChromaDB...")
        chroma_ids = df['chunk_id'].tolist()
        chroma_documents = df['text'].tolist()
        chroma_metadatas = [
            {
                'doc_id': str(row['doc_id']),
                'disease_category': str(row['disease_category']),
                'disease_subtype': str(row['disease_subtype']),
                'chunk_index': int(row['chunk_index']),
                'total_chunks': int(row['total_chunks']),
                'source_file': str(row['source_file'])
            }
            for row in df.to_dict('records')
        ]

        print(f"  Documents: {len(chroma_ids)}")
        print(f"  Embeddings: {len(embeddings)}")

        batch_size = 500
        print(f"\nüì§ Adding documents to ChromaDB (batch size: {batch_size})...")

        pbar = tqdm(range(0, len(chroma_ids), batch_size), desc="Adding batches", unit="batch")

        for i in pbar:
            batch_end = min(i + batch_size, len(chroma_ids))

            collection.add(
                ids=chroma_ids[i:batch_end],
                documents=chroma_documents[i:batch_end],
                embeddings=embeddings[i:batch_end].tolist(),
                metadatas=chroma_metadatas[i:batch_end]
            )

            pbar.set_postfix({
                'added': batch_end,
                'total': len(chroma_ids)
            })

        pbar.close()

        print(f"\n‚úì ChromaDB collection populated")
        print(f"  Total documents: {collection.count()}")
        print(f"  Storage: {self.config.CHROMA_DB_PATH}")

        return collection

print("‚úì EmbeddingGenerator class loaded")

‚úì EmbeddingGenerator class loaded


CELL 8: RUN EMBEDDING GENERATION (Skip if cached)

In [9]:
print("="*80)
print("STEP 2: EMBEDDING GENERATION")
print("="*80)

# Initialize embedder
embedder = EmbeddingGenerator(config, device)

# Try to load cached embeddings
embeddings = embedder.load_embeddings()

# Cached vectors are only usable if row i still matches processed_df row i. A row-count
# mismatch (a cheap staleness check) means processed_df was rebuilt after caching.
if embeddings is not None and embeddings.shape[0] != len(processed_df):
    print(f"\n‚ö† Cached embeddings have {embeddings.shape[0]} rows but processed data "
          f"has {len(processed_df)} chunks. Regenerating...")
    embeddings = None

if embeddings is None:
    print("\n‚ö† No cached embeddings found. Generating new embeddings...")
    print("‚è± This will take 5-15 minutes depending on dataset size\n")

    # Load model
    embedder.load_model()

    # Prepare documents with E5 prefix
    print("\nüìù Preparing documents with E5 passage prefix...")
    documents_with_prefix = [
        f"{config.PASSAGE_PREFIX}{text}"
        for text in processed_df['text'].tolist()
    ]
    print(f"‚úì Prepared {len(documents_with_prefix)} documents")

    # Generate embeddings
    embeddings = embedder.generate_embeddings(documents_with_prefix, batch_size=32)

    # Save to Google Drive
    embedder.save_embeddings(embeddings)

    print("\n‚úì Embeddings generated and saved to Google Drive!")

    # Release the model but keep the attribute, so the embedder object stays usable
    embedder.model = None
    cleanup_memory()
else:
    print("\n‚úì Using cached embeddings from Google Drive!")

print(f"\nüìä Embeddings Summary:")
print(f"  Shape: {embeddings.shape}")
print(f"  Dtype: {embeddings.dtype}")
print(f"  Memory: {embeddings.nbytes / 1e6:.2f} MB")

# `embeddings` is a notebook-level global here; the ChromaDB cell reads it directly.

STEP 2: EMBEDDING GENERATION
üìÇ Loading cached embeddings from Google Drive...
‚úì Loaded embeddings: (934, 384)

‚úì Using cached embeddings from Google Drive!

üìä Embeddings Summary:
  Shape: (934, 384)
  Dtype: float32
  Memory: 1.43 MB


CELL 9: SETUP CHROMADB VECTOR DATABASE (Skip if exists)

In [10]:
print("="*80)
print("STEP 3: CHROMADB VECTOR DATABASE SETUP")
print("="*80)

# Check if ChromaDB already exists
chroma_db_exists = (config.CHROMA_DB_PATH / "chroma.sqlite3").exists()

if chroma_db_exists:
    print("üìÇ ChromaDB already exists in Google Drive")
    print(f"   Path: {config.CHROMA_DB_PATH}")

    # Connect to existing database
    chroma_client = chromadb.PersistentClient(
        path=str(config.CHROMA_DB_PATH),
        settings=Settings(anonymized_telemetry=False)
    )

    try:
        collection = chroma_client.get_collection(name="clinical_notes")
        print(f"‚úì Connected to existing collection")
        print(f"  Documents: {collection.count()}")

        # Get a sample to verify
        sample = collection.peek(limit=1)
        print(f"  Sample metadata: {sample['metadatas'][0] if sample['metadatas'] else 'None'}")

    except Exception as e:
        print(f"‚ö† Error accessing collection: {e}")
        print("  Will recreate collection...")
        chroma_db_exists = False
    else:
        # A collection built from an older processed_df is stale: its rows no longer
        # correspond to the current chunks, so rebuild it from processed_df/embeddings.
        if collection.count() != len(processed_df):
            print(f"‚ö† Collection has {collection.count()} documents but processed data "
                  f"has {len(processed_df)} chunks")
            print("  Will recreate collection...")
            chroma_db_exists = False

if not chroma_db_exists:
    print("\n‚ö† ChromaDB not found. Creating new collection...")
    print("‚è± This will take 3-10 minutes\n")

    # Reinitialize embedder (without model, just for DB creation)
    embedder_for_db = EmbeddingGenerator(config, device)

    # Create ChromaDB collection
    collection = embedder_for_db.create_chromadb_collection(processed_df, embeddings)

    print("\n‚úì ChromaDB collection created and saved to Google Drive!")

print("\n‚úÖ ChromaDB Ready!")
print(f"   Location: {config.CHROMA_DB_PATH}")
print(f"   Documents: {collection.count()}")

# Notebook-level name used by the query-test cell
chroma_collection = collection

STEP 3: CHROMADB VECTOR DATABASE SETUP
üìÇ ChromaDB already exists in Google Drive
   Path: /content/drive/MyDrive/Clinical_RAG_System/chroma_db
‚úì Connected to existing collection
  Documents: 934
  Sample metadata: {'source_file': '/content/drive/MyDrive/mimic-iv-ext-direct-1.0.0/mimic-iv-ext-direct-1.0.0/Finished/Diabetes/Type I Diabetes/17517983-DS-78.json', 'chunk_index': 0, 'disease_category': 'Diabetes', 'total_chunks': 2, 'doc_id': '6ae055d0-ca1e-a4b3-e7f0-d0c0d15f312e', 'disease_subtype': 'Type I Diabetes'}

‚úÖ ChromaDB Ready!
   Location: /content/drive/MyDrive/Clinical_RAG_System/chroma_db
   Documents: 934


CELL 10: TEST CHROMADB QUERIES

In [11]:
print("="*80)
print("TESTING CHROMADB QUERIES")
print("="*80)

# Reconnect to the persisted collection if no earlier cell left one in the namespace
if 'chroma_collection' not in globals():
    chroma_client = chromadb.PersistentClient(
        path=str(config.CHROMA_DB_PATH),
        settings=Settings(anonymized_telemetry=False),
    )
    chroma_collection = chroma_client.get_collection(name="clinical_notes")

# Dedicated embedding model for these smoke-test queries (deleted at the end of the cell)
print("\nüì• Loading embedding model for queries...")
query_model = SentenceTransformer(
    config.EMBEDDING_MODEL,
    device=device,
    cache_folder=str(config.MODELS_CACHE),
)
query_model.eval()
print("‚úì Model loaded")


def test_query(query_text: str, top_k: int = 3):
    """Run one retrieval query against ChromaDB and print the ranked hits.

    The query gets the E5 query prefix before it is encoded with normalized embeddings.
    Similarity is printed as ``1 - distance``. That assumes the collection uses
    cosine space, as ``create_chromadb_collection`` configures it.
    """
    print(f"\nüîç Query: '{query_text}'")

    query_vector = query_model.encode(
        config.QUERY_PREFIX + query_text,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    hits = chroma_collection.query(
        query_embeddings=[query_vector.tolist()],
        n_results=top_k,
    )
    ranked = zip(hits['documents'][0], hits['metadatas'][0], hits['distances'][0])

    print(f"\nüìÑ Top {top_k} Results:")
    for rank, (doc, meta, dist) in enumerate(ranked, start=1):
        print(f"\n--- Result {rank} (Similarity: {1 - dist:.3f}) ---")
        print(f"Category: {meta.get('disease_category', 'N/A')}")
        print(f"Subtype: {meta.get('disease_subtype', 'N/A')}")
        print(f"Text preview: {doc[:200]}...")


# Smoke-test queries spanning three disease areas
test_queries = [
    "What are the symptoms of pneumonia?",
    "Treatment options for heart failure",
    "Diagnosis of diabetes mellitus"
]

print("\n" + "="*80)
print("RUNNING TEST QUERIES")
print("="*80)

for query in test_queries:
    test_query(query, top_k=2)
    print("\n" + "-"*80)

print("\n‚úì ChromaDB testing complete!")

# Release the query model and its (GPU) memory
del query_model
cleanup_memory()

TESTING CHROMADB QUERIES

üì• Loading embedding model for queries...


model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

‚úì Model loaded

RUNNING TEST QUERIES

üîç Query: 'What are the symptoms of pneumonia?'

üìÑ Top 2 Results:

--- Result 1 (Similarity: 0.874) ---
Category: Pneumonia
Subtype: Bacterial Pneumonia
Text preview: Her symptoms began approximately 10 days ago with a mild headache, muscle and joint pain, and a runny nose with clear discharge. Her condition temporarily improved after taking Tylenol, but she soon e...

--- Result 2 (Similarity: 0.870) ---
Category: Pneumonia
Subtype: Bacterial Pneumonia
Text preview: On ___, he developed a nonproductive cough, which was progressively worsened. He denies any SOB. He felt very weak on ___. He was diagnosed with pneumonia in the ED yesterday, sent home on zpack. He d...

--------------------------------------------------------------------------------

üîç Query: 'Treatment options for heart failure'

üìÑ Top 2 Results:

--- Result 1 (Similarity: 0.856) ---
Category: Atrial Fibrillation
Subtype: Persistent Atrial Fibrillation
Text preview: Thi

CELL 11: RAG PIPELINE WITH MISTRAL-7B

In [5]:
class RAGPipeline:
    """Retrieval-augmented generation over the ``clinical_notes`` ChromaDB collection.

    Workflow:
        1. ``load_models()`` loads the sentence-transformer embedder, opens the
           persistent ChromaDB collection and (optionally) a 4-bit quantized
           causal LM together with its tokenizer.
        2. ``generate_answer()`` embeds the query, retrieves the top-k chunks,
           packs them into a Mistral ``[INST]`` prompt and samples an answer.

    Attributes:
        config: Project ``Config``; this class reads EMBEDDING_MODEL,
            MODELS_CACHE, CHROMA_DB_PATH, GENERATION_MODEL, QUERY_PREFIX,
            MAX_CONTEXT_TOKENS, MAX_INPUT_LENGTH, MAX_NEW_TOKENS, TEMPERATURE,
            TOP_P, TOP_K and REPETITION_PENALTY from it.
        stats: Running totals over successful ``generate_answer`` calls;
            ``stats['errors']`` gets one entry per call that raised.
    """

    def __init__(self, config: Config):
        """Store the configuration; nothing is loaded until ``load_models()``.

        Args:
            config: Project ``Config`` instance with model names, paths and
                generation settings.
        """
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Models (loaded on demand)
        self.embedding_model = None
        self.generation_model = None
        self.tokenizer = None
        self.collection = None

        # Totals accumulate only for successful generations; 'errors' records
        # {'query', 'error'} dicts for calls that raised.
        self.stats = {
            'queries_processed': 0,
            'total_retrieval_time': 0,
            'total_generation_time': 0,
            'total_tokens_generated': 0,
            'errors': []
        }

    def load_models(self, load_generation_model: bool = True):
        """Load the embedder, open the ChromaDB collection and optionally the LLM.

        Args:
            load_generation_model: When False the tokenizer and causal LM are
                skipped (retrieval-only mode); ``generate_answer`` then returns
                the retrieved sources without a generated answer.

        Loader errors (e.g. a missing ``clinical_notes`` collection or a failed
        model download) propagate to the caller.
        """
        print("="*80)
        print("LOADING RAG PIPELINE MODELS")
        print("="*80)

        # 1. Load embedding model
        print("\n[1/3] üì• Loading embedding model...")
        self.embedding_model = SentenceTransformer(
            self.config.EMBEDDING_MODEL,
            device=self.device,
            cache_folder=str(self.config.MODELS_CACHE)
        )
        self.embedding_model.eval()
        print("      ‚úì Embedding model loaded")

        if self.device == "cuda":
            mem = get_gpu_memory()
            print(f"      GPU Memory: {mem['allocated']:.2f}GB / {mem['total']:.2f}GB")

        # 2. Load ChromaDB
        print("\n[2/3] üìÇ Loading ChromaDB collection...")
        chroma_client = chromadb.PersistentClient(
            path=str(self.config.CHROMA_DB_PATH),
            settings=Settings(anonymized_telemetry=False)
        )
        self.collection = chroma_client.get_collection(name="clinical_notes")
        print(f"      ‚úì Collection loaded ({self.collection.count()} documents)")

        # 3. Load generation model (optional, heavy)
        if load_generation_model:
            print("\n[3/3] ü§ñ Loading Mistral-7B (4-bit quantization)...")
            print("      ‚è± This may take 2-5 minutes...")

            # bfloat16 compute needs compute capability >= 8.0 (Ampere or newer).
            # Older GPUs such as the Colab T4 (7.5) lack native bf16, so use
            # float16 there instead.
            if self.device == "cuda" and torch.cuda.get_device_capability()[0] >= 8:
                compute_dtype = torch.bfloat16
            else:
                compute_dtype = torch.float16

            # Configure 4-bit quantization
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype
            )

            # Load tokenizer
            print("      Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.GENERATION_MODEL,
                cache_dir=str(self.config.MODELS_CACHE)
            )

            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            print("      ‚úì Tokenizer loaded")

            # Load model
            print("      Loading Mistral-7B model...")
            self.generation_model = AutoModelForCausalLM.from_pretrained(
                self.config.GENERATION_MODEL,
                quantization_config=bnb_config,
                device_map="auto",
                # NOTE(review): trust_remote_code=True runs Python code shipped in
                # the model repo. Mistral is supported natively by transformers,
                # so this can likely be False — confirm for the configured model.
                trust_remote_code=True,
                cache_dir=str(self.config.MODELS_CACHE),
                low_cpu_mem_usage=True
            )

            print("      ‚úì Mistral-7B loaded")

            if self.device == "cuda":
                mem = get_gpu_memory()
                print(f"      GPU Memory: {mem['allocated']:.2f}GB / {mem['total']:.2f}GB")
        else:
            print("\n[3/3] ‚è≠Ô∏è  Skipping generation model (retrieval-only mode)")

        print("\n" + "="*80)
        print("‚úÖ ALL MODELS LOADED SUCCESSFULLY")
        print("="*80)

    def encode_query(self, query: str) -> np.ndarray:
        """Embed one query string, prefixed with ``config.QUERY_PREFIX`` (E5 style).

        Returns:
            The L2-normalized embedding as a numpy array.
        """
        prefixed_query = f"{self.config.QUERY_PREFIX}{query}"
        with torch.no_grad():
            embedding = self.embedding_model.encode(
                prefixed_query,
                normalize_embeddings=True,
                convert_to_numpy=True
            )
        return embedding

    def retrieve_documents(self, query: str, top_k: int = 5,
                          filters: Optional[Dict] = None) -> Tuple[List[Dict], float]:
        """Query ChromaDB for the ``top_k`` chunks closest to ``query``.

        Args:
            query: Question text (embedded via ``encode_query``).
            top_k: Number of results requested.
            filters: Optional ChromaDB ``where`` clause on chunk metadata.

        Returns:
            (documents, elapsed_seconds). Each document dict has 'rank'
            (1-based), 'doc_id', 'text', 'distance', 'similarity'
            (``1 - distance``; assumes a cosine-distance collection — TODO
            confirm) and 'metadata'.
        """
        start_time = time.time()

        query_embedding = self.encode_query(query)

        query_params = {
            'query_embeddings': [query_embedding.tolist()],
            'n_results': top_k
        }

        if filters:
            query_params['where'] = filters

        results = self.collection.query(**query_params)

        # Chroma returns one list per query embedding; we sent exactly one.
        retrieved_docs = []
        for i, (doc_id, document, metadata, distance) in enumerate(zip(
            results['ids'][0],
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0]
        )):
            retrieved_docs.append({
                'rank': i + 1,
                'doc_id': doc_id,
                'text': document,
                'similarity': 1 - distance,
                'distance': distance,
                'metadata': metadata
            })

        return retrieved_docs, time.time() - start_time

    def format_context(self, documents: List[Dict]) -> str:
        """Concatenate retrieved documents into a context block.

        Token usage is estimated as ``len(text) // 4``. Once the estimate would
        exceed ``config.MAX_CONTEXT_TOKENS``, the current document is cut to the
        remaining character budget (only if more than 100 characters remain),
        and the rest are dropped.
        """
        context_parts = []
        current_tokens = 0

        for doc in documents:
            doc_text = f"""
Document {doc['rank']} [Disease: {doc['metadata'].get('disease_category', 'Unknown')}]:
{doc['text']}
---
"""
            doc_tokens = len(doc_text) // 4  # rough chars-per-token heuristic

            if current_tokens + doc_tokens > self.config.MAX_CONTEXT_TOKENS:
                remaining_chars = (self.config.MAX_CONTEXT_TOKENS - current_tokens) * 4
                if remaining_chars > 100:
                    doc_text = doc_text[:remaining_chars] + "...\n---\n"
                    context_parts.append(doc_text)
                break

            context_parts.append(doc_text)
            current_tokens += doc_tokens

        return "\n".join(context_parts)

    def create_prompt(self, query: str, context: str) -> str:
        """Build a Mistral-instruct prompt wrapping instructions, context and question.

        The returned string already begins with the ``<s>`` BOS marker, so it
        must be tokenized with ``add_special_tokens=False``.
        """
        system_instruction = """You are a clinical AI assistant with expertise in medical diagnostics and patient care. Your role is to provide accurate, evidence-based answers using the provided clinical notes.

Guidelines:
- Base your answers strictly on the provided clinical context
- Cite specific information from the documents when possible
- Use clear, professional medical terminology
- If the context doesn't contain sufficient information, clearly state what's missing
- Never fabricate medical information or make unsupported claims
- Consider differential diagnoses when appropriate
- Acknowledge uncertainty when present in the data"""

        prompt = f"""<s>[INST] {system_instruction}

Clinical Context from Patient Records:
{context}

Based on the clinical context above, answer the following question:

Question: {query}

Provide a clear, structured, evidence-based answer. [/INST]"""

        return prompt

    def _count_tokens(self, text: str) -> int:
        """Number of tokens in ``text`` without adding BOS/EOS."""
        return len(self.tokenizer(text, add_special_tokens=False)['input_ids'])

    def _fit_prompt(self, query: str, context: str) -> str:
        """Build the prompt, shrinking ``context`` so it fits ``MAX_INPUT_LENGTH`` tokens.

        Truncating the finished prompt would cut its tail — the question and
        the closing ``[/INST]`` tag — so the context is trimmed instead.
        """
        max_len = self.config.MAX_INPUT_LENGTH
        prompt = self.create_prompt(query, context)
        if self._count_tokens(prompt) <= max_len:
            return prompt

        # Tokens used by everything except the context (instructions, question, tags).
        overhead = self._count_tokens(self.create_prompt(query, ""))
        # Small safety margin: decode/re-encode and the "..." marker can shift
        # the count by a few tokens.
        budget = max(max_len - overhead - 16, 0)
        context_ids = self.tokenizer(context, add_special_tokens=False)['input_ids'][:budget]
        trimmed_context = self.tokenizer.decode(context_ids, skip_special_tokens=True)
        return self.create_prompt(query, trimmed_context + "\n...")

    def generate_answer(self, query: str, top_k: int = 5,
                       filters: Optional[Dict] = None,
                       show_progress: bool = True) -> Dict:
        """Run the full RAG pipeline: retrieve, format context, generate.

        Args:
            query: Natural-language question.
            top_k: Number of chunks to retrieve.
            filters: Optional ChromaDB ``where`` filter, e.g.
                ``{'disease_category': 'Pneumonia'}``.
            show_progress: Print step-by-step progress to stdout.

        Returns:
            Dict with 'query', 'answer', 'sources' and 'metadata'.
            ``metadata['success']`` is True only when an answer was generated.
            Exceptions are not raised: they are recorded in
            ``stats['errors']`` and reported as an ``"ERROR: ..."`` answer with
            ``success=False``.
        """
        pipeline_start = time.time()
        # GPU tensors released in `finally`, including on failure.
        inputs = outputs = None

        if show_progress:
            print(f"\nüîç Processing query: '{query}'")

        try:
            # Retrieval
            if show_progress:
                print("   [1/3] Retrieving documents...")

            documents, retrieval_time = self.retrieve_documents(query, top_k, filters)

            if show_progress:
                print(f"   ‚úì Retrieved {len(documents)} documents ({retrieval_time:.2f}s)")

            if not documents:
                return {
                    'query': query,
                    'answer': "No relevant documents found in the database.",
                    'sources': [],
                    'metadata': {
                        'retrieval_time': retrieval_time,
                        'generation_time': 0,
                        'total_time': time.time() - pipeline_start,
                        'error': 'No documents retrieved'
                    }
                }

            if self.generation_model is None or self.tokenizer is None:
                # Retrieval-only mode (load_models(load_generation_model=False)):
                # return the sources rather than failing on a None model.
                if show_progress:
                    print("   [!] Generation model not loaded - returning retrieved documents only")
                return {
                    'query': query,
                    'answer': "Generation model not loaded (retrieval-only mode). See the retrieved sources.",
                    'sources': documents,
                    'metadata': {
                        'retrieval_time': retrieval_time,
                        'generation_time': 0,
                        'total_time': time.time() - pipeline_start,
                        'documents_retrieved': len(documents),
                        'error': 'Generation model not loaded',
                        'success': False
                    }
                }

            # Format context
            if show_progress:
                print("   [2/3] Formatting context...")
            context = self.format_context(documents)

            # Generate
            if show_progress:
                print("   [3/3] Generating answer...")

            prompt = self._fit_prompt(query, context)

            generation_start = time.time()

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                # The prompt already starts with a literal "<s>" (BOS); letting
                # the tokenizer add another would feed the model two BOS tokens.
                add_special_tokens=False
            ).to(self.generation_model.device)

            input_length = inputs['input_ids'].shape[1]

            gen_config = {
                'max_new_tokens': self.config.MAX_NEW_TOKENS,
                'temperature': self.config.TEMPERATURE,
                'top_p': self.config.TOP_P,
                'top_k': self.config.TOP_K,
                'repetition_penalty': self.config.REPETITION_PENALTY,
                'do_sample': True,
                'pad_token_id': self.tokenizer.pad_token_id,
                'eos_token_id': self.tokenizer.eos_token_id,
            }

            with torch.no_grad():
                outputs = self.generation_model.generate(**inputs, **gen_config)

            # Decode only the newly generated tokens (skip the echoed prompt).
            generated_text = self.tokenizer.decode(
                outputs[0][input_length:],
                skip_special_tokens=True
            )

            generation_time = time.time() - generation_start
            output_length = len(outputs[0]) - input_length

            if show_progress:
                print(f"   ‚úì Generated {output_length} tokens ({generation_time:.2f}s)")

            # Statistics
            self.stats['queries_processed'] += 1
            self.stats['total_retrieval_time'] += retrieval_time
            self.stats['total_generation_time'] += generation_time
            self.stats['total_tokens_generated'] += output_length

            return {
                'query': query,
                'answer': generated_text.strip(),
                'sources': documents,
                'metadata': {
                    'retrieval_time': retrieval_time,
                    'generation_time': generation_time,
                    'total_time': time.time() - pipeline_start,
                    'input_tokens': input_length,
                    'output_tokens': output_length,
                    'documents_retrieved': len(documents),
                    'success': True
                }
            }

        except Exception as e:
            if show_progress:
                print(f"   ‚ùå Error: {str(e)}")

            self.stats['errors'].append({'query': query, 'error': str(e)})

            return {
                'query': query,
                'answer': f"ERROR: {str(e)}",
                'sources': [],
                'metadata': {
                    'error': str(e),
                    'success': False
                }
            }

        finally:
            if inputs is not None or outputs is not None:
                # Drop references to the GPU tensors before freeing cached memory.
                inputs = outputs = None
                cleanup_memory()

    def print_result(self, result: Dict):
        """Pretty-print a ``generate_answer`` result: answer, top-3 sources, timings.

        Timing and token counts are shown only when ``metadata['success']`` is
        truthy.
        """
        print("\n" + "="*80)
        print("QUERY RESULT")
        print("="*80)
        print(f"\n‚ùì Query: {result['query']}")
        print(f"\nüí¨ Answer:\n{result['answer']}")

        if result['sources']:
            print(f"\nüìö Sources ({len(result['sources'])} documents):")
            for source in result['sources'][:3]:  # Show top 3
                print(f"\n  ‚Ä¢ Document {source['rank']} (Similarity: {source['similarity']:.3f})")
                print(f"    Category: {source['metadata'].get('disease_category', 'N/A')}")
                print(f"    Preview: {source['text'][:150]}...")

        if 'metadata' in result and result['metadata'].get('success'):
            meta = result['metadata']
            print(f"\n‚è±Ô∏è  Performance:")
            print(f"    Retrieval: {meta['retrieval_time']:.2f}s")
            print(f"    Generation: {meta['generation_time']:.2f}s")
            print(f"    Total: {meta['total_time']:.2f}s")
            print(f"    Tokens: {meta['output_tokens']} generated")

        print("\n" + "="*80)

print("‚úì RAGPipeline class loaded")

‚úì RAGPipeline class loaded


In [13]:
"""
===============================================================================
CELL 11B: CLEAR CHROMADB SINGLETON (Run if you get errors)
===============================================================================
"""

# ChromaDB caches one System per client identifier on SharedSystemClient.
# Re-creating a PersistentClient for the same path with different settings in
# the same kernel can then fail (presumably with "An instance of Chroma already
# exists ... with different settings"). Clearing that cache lets the next
# chromadb.PersistentClient(...) start fresh.
import chromadb
from chromadb.api.client import SharedSystemClient

if hasattr(SharedSystemClient, 'clear_system_cache'):
    # Public API for resetting the cache.
    SharedSystemClient.clear_system_cache()
    print("‚úì Cleared system cache")
elif hasattr(SharedSystemClient, '_identifer_to_system'):
    # Fallback for versions without clear_system_cache(). The private attribute
    # name really is spelled "_identifer" in chromadb.
    SharedSystemClient._identifer_to_system.clear()
    print("‚úì Cleared _identifer_to_system cache")

# Deleting sys.modules['chromadb'] and re-importing is deliberately NOT done:
# it re-executes only the top-level package while every cached submodule
# (chromadb.api.client, chromadb.config, ...) is reused, so it resets nothing
# and leaves duplicate module objects around.

cleanup_memory()
print("\n‚úÖ ChromaDB singleton reset complete!")

‚úì Cleared system cache
‚úì Reimported chromadb module

‚úÖ ChromaDB singleton reset complete!


CELL 12: INITIALIZE RAG PIPELINE

In [None]:
print("\n".join(["=" * 80, "INITIALIZING COMPLETE RAG PIPELINE", "=" * 80]))

# Construct the pipeline, then load the embedder, the ChromaDB collection and
# the quantized generator (first run downloads the weights: 3-5 minutes).
rag_pipeline = RAGPipeline(config)
rag_pipeline.load_models(load_generation_model=True)

print("\n".join(["\n‚úÖ RAG Pipeline is ready!", "üí° Available as: rag_pipeline"]))

# Report GPU memory once every model is resident.
if torch.cuda.is_available():
    mem = get_gpu_memory()
    print("\n".join([
        f"\nüìä Final GPU Memory Usage:",
        f"   Allocated: {mem['allocated']:.2f}GB",
        f"   Total: {mem['total']:.2f}GB",
        f"   Free: {mem['free']:.2f}GB",
    ]))

INITIALIZING COMPLETE RAG PIPELINE
LOADING RAG PIPELINE MODELS

[1/3] üì• Loading embedding model...
      ‚úì Embedding model loaded
      GPU Memory: 0.13GB / 15.83GB

[2/3] üìÇ Loading ChromaDB collection...
      ‚úì Collection loaded (934 documents)

[3/3] ü§ñ Loading Mistral-7B (4-bit quantization)...
      ‚è± This may take 2-5 minutes...
      Loading tokenizer...
      ‚úì Tokenizer loaded
      Loading Mistral-7B model...


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

CELL 13: TEST RAG PIPELINE WITH SAMPLE QUERIES

In [None]:
print("="*80)
print("TESTING RAG PIPELINE")
print("="*80)

# Define test queries
test_queries = [
    "What are the common symptoms of pneumonia?",
    "How is heart failure typically treated?",
    "What are the complications of diabetes mellitus?",
    "Describe the diagnostic criteria for sepsis",
    "What medications are used for hypertension?"
]

print(f"\nüß™ Running {len(test_queries)} test queries...\n")

results = []

for i, query in enumerate(test_queries, 1):
    print(f"\n{'='*80}")
    print(f"TEST QUERY {i}/{len(test_queries)}")
    print(f"{'='*80}")

    # Generate answer
    result = rag_pipeline.generate_answer(query, top_k=5, show_progress=True)
    results.append(result)

    # Print result
    rag_pipeline.print_result(result)

    # Short pause between queries
    time.sleep(1)

# Summary statistics.
# Computed from this run's results, not from rag_pipeline.stats: those totals
# are cumulative over the whole session and only count successful queries, so
# dividing them by len(results) is wrong after a failure or a re-run.
successful = [r for r in results if r['metadata'].get('success')]
n_successful = len(successful)

print("\n" + "="*80)
print("TEST SUMMARY")
print("="*80)
print(f"\nTotal queries processed: {n_successful}")
if n_successful < len(results):
    print(f"Failed queries: {len(results) - n_successful}")

if successful:
    total_retrieval = sum(r['metadata']['retrieval_time'] for r in successful)
    total_generation = sum(r['metadata']['generation_time'] for r in successful)
    total_tokens = sum(r['metadata']['output_tokens'] for r in successful)
    print(f"Average retrieval time: {total_retrieval / n_successful:.2f}s")
    print(f"Average generation time: {total_generation / n_successful:.2f}s")
    print(f"Total tokens generated: {total_tokens}")
    print(f"Average tokens per query: {total_tokens / n_successful:.0f}")


CELL 14: INTERACTIVE QUERY INTERFACE (COLAB)

In [None]:
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

def create_query_interface():
    """Build an ipywidgets form for querying ``rag_pipeline`` interactively in Colab.

    Reads the notebook globals ``processed_df`` (to list disease categories) and
    ``rag_pipeline`` (to answer queries), so both must exist before use.

    Returns:
        widgets.VBox holding the query box, document-count slider, category
        dropdown, submit button and an output area where results are printed.
    """

    # Widgets
    query_input = widgets.Textarea(
        value='',
        placeholder='Enter your clinical question here...',
        description='Query:',
        layout=widgets.Layout(width='100%', height='80px')
    )

    top_k_slider = widgets.IntSlider(
        value=5,
        min=1,
        max=10,
        step=1,
        description='Documents:',
        style={'description_width': 'initial'}
    )

    category_dropdown = widgets.Dropdown(
        # dropna(): a missing category is NaN (a float), which sorted() cannot
        # compare with the string categories.
        options=['All'] + sorted(processed_df['disease_category'].dropna().unique().tolist()),
        value='All',
        description='Category:',
        style={'description_width': 'initial'}
    )

    submit_button = widgets.Button(
        description='Generate Answer',
        button_style='primary',
        icon='search'
    )

    output_area = widgets.Output()

    def on_submit(button):
        """Run the RAG pipeline for the current form values and print the result."""
        with output_area:
            clear_output()

            query = query_input.value.strip()
            if not query:
                print("‚ö†Ô∏è Please enter a query")
                return

            # Prepare filters
            filters = None
            if category_dropdown.value != 'All':
                filters = {'disease_category': category_dropdown.value}

            # Generation can take a while; disable the button so repeated clicks
            # don't queue duplicate requests, and always re-enable it afterwards.
            submit_button.disabled = True
            try:
                result = rag_pipeline.generate_answer(
                    query,
                    top_k=top_k_slider.value,
                    filters=filters,
                    show_progress=True
                )
            finally:
                submit_button.disabled = False

            # Display result
            rag_pipeline.print_result(result)

    submit_button.on_click(on_submit)

    # Layout
    interface = widgets.VBox([
        widgets.HTML("<h2>üè• Clinical RAG System - Interactive Query Interface</h2>"),
        query_input,
        widgets.HBox([top_k_slider, category_dropdown]),
        submit_button,
        output_area
    ])

    return interface

# Create and display interface
print("üéØ Creating interactive query interface...")
query_interface = create_query_interface()
display(query_interface)

CELL 15: STREAMLIT APP CODE GENERATOR

In [None]:

def generate_streamlit_app(categories: Optional[List[str]] = None):
    """Write a Streamlit front-end for the RAG pipeline to ``DRIVE_ROOT/streamlit_app.py``.

    The generated script expects ``Config`` and ``RAGPipeline`` to be pasted in
    (see the placeholder comment at its top).

    Args:
        categories: Disease categories offered in the app's filter. When None,
            they are taken from the notebook global ``processed_df`` if it
            exists; otherwise a placeholder list is used.

    Returns:
        pathlib.Path of the written file.
    """
    if categories is None:
        if 'processed_df' in globals():
            # The sidebar filter matches ChromaDB's 'disease_category' metadata
            # exactly, so offer the categories actually present in the data.
            categories = sorted(processed_df['disease_category'].dropna().unique().tolist())
        else:
            categories = ["Cardiovascular", "Respiratory", "Endocrine"]

    streamlit_code = '''
import streamlit as st
import sys
from pathlib import Path
import json
import html
import time
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Add your imports and RAGPipeline class here
# Copy the Config and RAGPipeline classes from the Colab notebook

# Page config
st.set_page_config(
    page_title="Clinical RAG System",
    page_icon="üè•",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .query-box {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
    .source-card {
        background-color: #ffffff;
        border: 1px solid #e0e0e0;
        border-radius: 0.5rem;
        padding: 1rem;
        margin: 0.5rem 0;
    }
    .metric-card {
        background-color: #f8f9fa;
        padding: 1rem;
        border-radius: 0.5rem;
        text-align: center;
    }
</style>
""", unsafe_allow_html=True)

# Initialize session state
if 'rag_pipeline' not in st.session_state:
    st.session_state.rag_pipeline = None
    st.session_state.query_history = []
    st.session_state.models_loaded = False

# Sidebar
with st.sidebar:
    st.title("‚öôÔ∏è Settings")

    # Model loading
    if not st.session_state.models_loaded:
        if st.button("üöÄ Load Models", type="primary"):
            with st.spinner("Loading models... This may take 3-5 minutes..."):
                try:
                    config = Config()  # Use your Config class
                    st.session_state.rag_pipeline = RAGPipeline(config)
                    st.session_state.rag_pipeline.load_models(load_generation_model=True)
                    st.session_state.models_loaded = True
                    st.success("‚úÖ Models loaded successfully!")
                    st.rerun()
                except Exception as e:
                    st.error(f"‚ùå Error loading models: {str(e)}")
    else:
        st.success("‚úÖ Models loaded")

        # Query parameters
        st.subheader("Query Parameters")
        top_k = st.slider("Number of documents", 1, 10, 5)

        # Optional: Category filter (values must match the indexed disease_category metadata)
        categories = ["All"] + __CATEGORIES__
        category_filter = st.selectbox("Disease Category", categories)

        # Clear history
        if st.button("üóëÔ∏è Clear History"):
            st.session_state.query_history = []
            st.rerun()

# Main content
st.markdown('<h1 class="main-header">üè• Clinical RAG System</h1>', unsafe_allow_html=True)

# Check if models are loaded
if not st.session_state.models_loaded:
    st.info("üëà Please load the models from the sidebar to begin.")
    st.stop()

# Query input
st.subheader("üí¨ Ask a Clinical Question")

col1, col2 = st.columns([4, 1])

with col1:
    query = st.text_area(
        "Enter your question:",
        height=100,
        placeholder="e.g., What are the symptoms of pneumonia?"
    )

with col2:
    st.write("")  # Spacing
    st.write("")  # Spacing
    submit = st.button("üîç Generate Answer", type="primary", use_container_width=True)

# Process query
if submit and query:
    with st.spinner("üîÑ Processing your query..."):
        try:
            # Prepare filters
            filters = None
            if category_filter != "All":
                filters = {"disease_category": category_filter}

            # Generate answer
            result = st.session_state.rag_pipeline.generate_answer(
                query,
                top_k=top_k,
                filters=filters,
                show_progress=False
            )

            # Add to history
            st.session_state.query_history.insert(0, result)

            # Display results
            st.markdown("---")

            # Answer
            st.subheader("üí° Answer")
            # Escape model output before embedding it in raw HTML, and keep it on
            # one line (<br>) so blank lines cannot end the HTML block early.
            answer_html = html.escape(result["answer"]).replace("\\n", "<br>")
            st.markdown(f'<div class="query-box">{answer_html}</div>', unsafe_allow_html=True)

            # Metrics
            if result["metadata"].get("success"):
                col1, col2, col3, col4 = st.columns(4)

                with col1:
                    st.metric("‚è±Ô∏è Total Time", f"{result['metadata']['total_time']:.2f}s")
                with col2:
                    st.metric("üì• Retrieval", f"{result['metadata']['retrieval_time']:.2f}s")
                with col3:
                    st.metric("ü§ñ Generation", f"{result['metadata']['generation_time']:.2f}s")
                with col4:
                    st.metric("üìù Tokens", result['metadata']['output_tokens'])

            # Sources
            st.subheader("üìö Sources")
            for i, source in enumerate(result['sources'][:5], 1):
                with st.expander(f"Document {i} - {source['metadata'].get('disease_category', 'N/A')} (Similarity: {source['similarity']:.3f})"):
                    st.markdown(f"**Category:** {source['metadata'].get('disease_category', 'N/A')}")
                    st.markdown(f"**Subtype:** {source['metadata'].get('disease_subtype', 'N/A')}")
                    st.markdown("**Content:**")
                    st.text(source['text'][:500] + "..." if len(source['text']) > 500 else source['text'])

        except Exception as e:
            st.error(f"‚ùå Error: {str(e)}")

# Query history
if st.session_state.query_history:
    st.markdown("---")
    st.subheader("üìú Query History")

    for i, past_result in enumerate(st.session_state.query_history[:5], 1):
        with st.expander(f"{i}. {past_result['query'][:80]}..."):
            st.markdown(f"**Answer:** {past_result['answer'][:300]}...")
            if past_result['metadata'].get('success'):
                st.caption(f"‚è±Ô∏è {past_result['metadata']['total_time']:.2f}s | "
                          f"üìö {len(past_result['sources'])} sources")

# Footer
st.markdown("---")
st.caption("üè• Clinical RAG System | Powered by MIMIC-IV-EXT, E5 Embeddings, and Mistral-7B")
'''

    # Fill in the category list as a Python list literal.
    streamlit_code = streamlit_code.replace(
        "__CATEGORIES__", repr([str(c) for c in categories])
    )

    # Save to Google Drive (explicit UTF-8: the app contains emoji)
    output_path = config.DRIVE_ROOT / "streamlit_app.py"
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(streamlit_code)

    print(f"‚úì Streamlit app saved to: {output_path}")
    print("\nüìù To run the Streamlit app:")
    print("1. Install: !pip install streamlit")
    print(f"2. Run: !streamlit run {output_path}")
    print("3. Or use ngrok/localtunnel to expose the app")

    return output_path

# Generate app
print("="*80)
print("GENERATING STREAMLIT APP")
print("="*80)

app_path = generate_streamlit_app()

print("\nüí° Next steps:")
print("1. Copy the Config and RAGPipeline classes into the Streamlit app")
print("2. Update the Google Drive paths in the Config class")
print("3. Run the app using: streamlit run streamlit_app.py")

CELL 16: SAVE FINAL STATE

In [None]:
def save_pipeline_state():
    """Persist a JSON snapshot of the pipeline setup to ``DRIVE_ROOT/pipeline_state.json``.

    Counters are read from notebook globals when they exist (``processed_df``,
    ``embeddings``, ``chroma_collection``, ``rag_pipeline``); a missing one is
    recorded as 0, or None for the embeddings shape.

    Returns:
        dict: The snapshot that was serialized.
    """
    notebook_vars = globals()

    def _defined(name):
        return name in notebook_vars

    config_section = {
        'embedding_model': config.EMBEDDING_MODEL,
        'generation_model': config.GENERATION_MODEL,
        'paths': {
            'processed_data': str(config.PROCESSED_DATA_DIR),
            'chroma_db': str(config.CHROMA_DB_PATH),
            'models_cache': str(config.MODELS_CACHE)
        }
    }

    stats_section = {
        'processed_chunks': len(processed_df) if _defined('processed_df') else 0,
        'embeddings_shape': embeddings.shape if _defined('embeddings') else None,
        'chromadb_count': chroma_collection.count() if _defined('chroma_collection') else 0,
        'queries_processed': rag_pipeline.stats['queries_processed'] if _defined('rag_pipeline') else 0
    }

    state = {
        'config': config_section,
        'stats': stats_section,
        'timestamp': datetime.now().isoformat()
    }

    state_path = config.DRIVE_ROOT / "pipeline_state.json"
    with open(state_path, 'w') as handle:
        json.dump(state, handle, indent=2)

    print(f"‚úì Pipeline state saved to: {state_path}")
    return state

print("="*80)
print("SAVING PIPELINE STATE")
print("="*80)

final_state = save_pipeline_state()

print("\nüìä Final State Summary:")
print(json.dumps(final_state, indent=2))

print("\n" + "="*80)
print("‚úÖ ALL SETUP COMPLETE!")
print("="*80)
print("\nüí° Everything is cached in Google Drive:")
print(f"   üìÅ Root: {config.DRIVE_ROOT}")
print(f"   üìÑ Processed data: {config.PROCESSED_DF_PATH}")
print(f"   üî¢ Embeddings: {config.EMBEDDINGS_PATH}")
print(f"   üóÑÔ∏è ChromaDB: {config.CHROMA_DB_PATH}")
print(f"   ü§ñ Models cache: {config.MODELS_CACHE}")
print("\nüöÄ Next time you run this notebook, it will load from cache!")
print("‚ö° Subsequent runs will be 10-20x faster!")


In [None]:
# Save Streamlit app.
# generate_streamlit_app() writes the file itself and returns its Path.
# Re-opening that path with 'w' and writing the return value would first
# truncate the app and then fail (write() needs a str, not a Path).
streamlit_path = generate_streamlit_app()

print(f"\n‚úì Streamlit app generated: {streamlit_path}")
print(f"  Size: {streamlit_path.stat().st_size / 1024:.2f} KB")

print("\n" + "="*80)
print("üì± HOW TO RUN STREAMLIT IN COLAB")
print("="*80)
print("\n1. Install streamlit-localtunnel:")
print("   !pip install streamlit-localtunnel")
print("\n2. Run in a new cell:")
print(f"   !streamlit run {streamlit_path} > /content/streamlit.log 2>&1 &")
print("   !npx localtunnel --port 8501")
print("\n3. Copy the generated URL and access your app!")

print("\n‚úÖ All cells complete! Your RAG system is ready.")

In [None]:
!pip install streamlit-localtunnel
!streamlit run /content/drive/MyDrive/Clinical_RAG_System/streamlit_app.py &
!npx localtunnel --port 8501