<a href="https://colab.research.google.com/github/Niv0902/EcoFish-/blob/main/HW1_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

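# Ecological RAG System – Premium Edition 🌊

A retrieval-augmented question-answering demo over a small set of IOLR marine and freshwater research papers. The system works as follows:

- **Embedding:** papers are embedded with a SentenceTransformer model, falling back to TF-IDF.
- **Indexing:** embeddings are indexed in ChromaDB, falling back to a simple in-memory store.
- **Answering:** questions are answered either by the OpenAI chat API or by a Markdown template.
- **Interface:** everything is served behind a Gradio web UI.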


In [1]:
# If running on Colab, uncomment to install dependencies (openai is optional, only needed for GPT answers):
# %pip install -q gradio==4.* sentence-transformers chromadb scikit-learn openai

In [2]:
# -*- coding: utf-8 -*-
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
import re
import time
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Package detection
try:
    import chromadb
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False

try:
    from sentence_transformers import SentenceTransformer
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

try:
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False

In [4]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

In [5]:
class SimpleVectorStore:
    """In-memory stand-in for a ChromaDB collection.

    Implements only the subset of the collection API that EcologicalRAG uses
    (``add``, ``query``, ``count``) and returns query results in the same
    nested ``{'ids': [[...]], 'documents': [[...]], ...}`` shape.
    """

    def __init__(self):
        # Parallel lists: index k in each list describes the same document.
        self.documents, self.embeddings, self.metadatas, self.ids = [], [], [], []

    def add(self, embeddings, documents, metadatas, ids):
        """Append a batch of documents; the four sequences must be aligned."""
        for target, batch in (
            (self.embeddings, embeddings),
            (self.documents, documents),
            (self.metadatas, metadatas),
            (self.ids, ids),
        ):
            target.extend(batch)

    def query(self, query_embeddings, n_results=5):
        """Return the ``n_results`` documents most similar to the first query vector.

        Similarity is cosine similarity; the reported ``distances`` are
        ``1 - similarity``, ordered from most to least similar.
        """
        if not self.embeddings:
            return {key: [[]] for key in ('ids', 'documents', 'metadatas', 'distances')}

        scores = cosine_similarity(query_embeddings, self.embeddings)[0]
        ranked = np.argsort(scores)[::-1][:n_results]

        def pick(column):
            return [column[k] for k in ranked]

        return {
            'ids': [pick(self.ids)],
            'documents': [pick(self.documents)],
            'metadatas': [pick(self.metadatas)],
            'distances': [[1 - scores[k] for k in ranked]],
        }

    def count(self):
        """Number of stored documents."""
        return len(self.documents)

In [6]:
class EcologicalRAG:
    """Retrieval-augmented question answering over ecological research papers.

    Components are chosen from the optional packages detected at import time:

    * embeddings -- SentenceTransformer ``all-MiniLM-L6-v2`` when available,
      otherwise a TF-IDF vectorizer fitted on the first batch it embeds;
    * vector store -- an in-memory ChromaDB collection when available,
      otherwise ``SimpleVectorStore``;
    * generation -- the OpenAI chat API when a key is supplied and the
      ``openai`` package is installed, otherwise a Markdown template.
    """

    def __init__(self, openai_api_key=None):
        """Create the system.

        Args:
            openai_api_key: Optional OpenAI API key. When falsy (or ``openai``
                is not installed) answers are built from a template.
        """
        self.papers = []     # papers actually indexed by the last load_papers call
        self.fitted = False  # whether the TF-IDF fallback has been fitted
        self._initialize_components(openai_api_key)

    def _initialize_components(self, openai_api_key):
        """Initialize embedding model, vector store and generator, falling back silently."""
        # Setup embedding model
        if TRANSFORMERS_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
                self.use_transformers = True
            except Exception:
                self._setup_tfidf()
        else:
            self._setup_tfidf()

        # Setup vector store
        if CHROMADB_AVAILABLE:
            try:
                client = chromadb.Client()
                try:
                    self.collection = client.get_collection("ecological_papers")
                except Exception:
                    self.collection = client.create_collection("ecological_papers")
                self.use_chromadb = True
            except Exception:
                self.collection = SimpleVectorStore()
                self.use_chromadb = False
        else:
            self.collection = SimpleVectorStore()
            self.use_chromadb = False

        # Setup OpenAI. The module-level key is honoured by both the legacy
        # (<1.0) interface and the v1 module-level client.
        if openai_api_key and OPENAI_AVAILABLE:
            openai.api_key = openai_api_key
            self.use_openai = True
        else:
            self.use_openai = False

    def _setup_tfidf(self):
        """Use a TF-IDF vectorizer (fitted lazily on first use) as the embedding fallback."""
        self.use_transformers = False
        self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')

    @staticmethod
    def _empty_results():
        """Return a fresh empty result dict in the vector-store query shape."""
        return {'ids': [[]], 'documents': [[]], 'metadatas': [[]], 'distances': [[]]}

    @staticmethod
    def _unique(items, limit=3):
        """Deduplicate ``items`` keeping first-seen order, truncated to ``limit``."""
        # dict.fromkeys preserves order; set() order varies between processes
        # because string hashing is randomised.
        return list(dict.fromkeys(items))[:limit]

    def preprocess_text(self, text):
        """Clean text for embedding.

        Characters other than word characters, whitespace, ``-``, ``.``, ``(``
        and ``)`` become spaces, then runs of whitespace collapse to a single
        space. Falsy input returns ``""``.
        """
        if not text:
            return ""
        # Substitute first so the spaces introduced here are collapsed as well.
        text = re.sub(r'[^\w\s\-\.\(\)]', ' ', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def extract_entities(self, text):
        """Extract heuristic ecological entities from ``text``.

        Returns a dict with keys ``'species'``, ``'locations'`` and
        ``'methods'``, each a list of at most three unique matches in order of
        first appearance.

        The species pattern is only a rough binomial-nomenclature heuristic:
        it matches any capitalised word followed by a lowercase word, so it
        also captures ordinary phrases (e.g. "Kinneret water").
        """
        species = re.findall(r'\b[A-Z][a-z]+ [a-z]+\b', text)
        locations = re.findall(
            r'\b(Mediterranean|Red Sea|Lake Kinneret|Eastern Mediterranean|Levantine)\b',
            text, re.IGNORECASE
        )
        methods = re.findall(
            r'\b(PCR|DNA|sequencing|survey|analysis|modeling)\b',
            text, re.IGNORECASE
        )
        return {
            'species': self._unique(species),
            'locations': self._unique(locations),
            'methods': self._unique(methods),
        }

    def generate_embeddings(self, texts):
        """Embed ``texts`` with the active model, returning a 2-D numpy array.

        With the TF-IDF fallback the vectorizer is fitted on the first batch
        it sees and only transforms afterwards.
        """
        if self.use_transformers:
            return self.embedding_model.encode(texts, show_progress_bar=False)
        if not self.fitted:
            self.tfidf.fit(texts)
            self.fitted = True
        return self.tfidf.transform(texts).toarray()

    def load_papers(self, papers_data):
        """Index papers into the vector store.

        Args:
            papers_data: Iterable of paper dicts with optional keys ``title``,
                ``authors``, ``journal``, ``year``, ``doi`` and ``abstract``.
                Papers with a missing/empty abstract, or whose cleaned
                title+abstract is shorter than 50 characters, are skipped.

        Returns:
            True if at least one paper was indexed, otherwise False.
        """
        # `or ''` also guards against an explicit None abstract.
        valid_papers = [p for p in (papers_data or []) if (p.get('abstract') or '').strip()]
        if not valid_papers:
            return False

        documents, metadatas, ids, indexed = [], [], [], []
        for i, paper in enumerate(valid_papers):
            text = self.preprocess_text(f"{paper.get('title', '')} {paper.get('abstract', '')}")
            if len(text) < 50:
                continue

            entities = self.extract_entities(text)
            metadatas.append({
                'title': paper.get('title', 'Unknown'),
                'authors': paper.get('authors', 'Unknown'),
                'journal': paper.get('journal', 'Unknown'),
                'year': paper.get('year', 2022),
                'doi': paper.get('doi', ''),
                'species': ', '.join(entities['species']),
                'locations': ', '.join(entities['locations']),
                'methods': ', '.join(entities['methods']),
            })
            documents.append(text)
            ids.append(f"paper_{i}")
            indexed.append(paper)

        if not documents:
            return False

        embeddings = self.generate_embeddings(documents)
        self._add_to_store(embeddings, documents, metadatas, ids)

        # Only papers that actually reached the index count as loaded.
        self.papers = indexed
        return True

    def _add_to_store(self, embeddings, documents, metadatas, ids):
        """Add one batch to the active store, switching to SimpleVectorStore if ChromaDB rejects it."""
        if getattr(self, 'use_chromadb', False):
            try:
                # ChromaDB expects plain Python lists rather than numpy arrays.
                self.collection.add(
                    embeddings=np.asarray(embeddings).tolist(),
                    documents=documents,
                    metadatas=metadatas,
                    ids=ids,
                )
                return
            except Exception:
                # Unexpected API mismatch: fall back and record the switch so
                # search() and get_status() reflect the store really in use.
                self.collection = SimpleVectorStore()
                self.use_chromadb = False
        self.collection.add(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )

    def search(self, query, n_results=3):
        """Return the stored documents most similar to ``query``.

        ``n_results`` is clamped to the number of stored documents, because
        requesting more may raise or warn depending on the ChromaDB version.
        An empty store yields empty results without embedding the query, so
        the TF-IDF fallback is never fitted on a lone query.
        """
        n_results = min(int(n_results), self.collection.count())
        if n_results <= 0:
            return self._empty_results()

        query_embedding = self.generate_embeddings([self.preprocess_text(query)])

        if getattr(self, 'use_chromadb', False):
            return self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=n_results
            )
        return self.collection.query(
            query_embeddings=query_embedding,
            n_results=n_results
        )

    def _generate_openai_response(self, query, papers, search_results):
        """Answer with the OpenAI chat API, falling back to the template on any failure."""
        documents = search_results['documents'][0]
        context = "\n\n".join(
            f"Paper: {paper['title']}\n"
            f"Authors: {paper['authors']}\n"
            f"Content: {doc[:400]}..."
            for paper, doc in zip(papers, documents)
        )

        prompt = f"""You are an expert marine ecologist. Answer this question based on the research provided:

Question: {query}

Research Papers:
{context}

Provide a comprehensive answer citing the research. Focus on Mediterranean and freshwater ecosystems."""

        messages = [
            {"role": "system", "content": "You are an expert marine and freshwater ecologist."},
            {"role": "user", "content": prompt}
        ]

        try:
            if hasattr(openai, "chat") and hasattr(openai.chat, "completions"):
                # openai>=1.0: module-level client configured via openai.api_key.
                response = openai.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    max_tokens=800,
                    temperature=0.7
                )
            else:
                # openai<1.0 legacy interface (removed in 1.0).
                response = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    max_tokens=800,
                    temperature=0.7
                )
            content = response.choices[0].message.content
        except Exception:
            # Deliberate best-effort: any API/network/auth error degrades to the template.
            content = None
        return content or self._generate_template_response(query, papers, search_results)

    def _generate_template_response(self, query, papers, search_results):
        """Build a deterministic Markdown answer from result metadata (no LLM)."""
        response = f"## Research Findings for: {query}\n\n"

        for i, paper in enumerate(papers[:5]):
            response += f"### {i+1}. {paper['title']}\n"
            response += f"**Authors:** {paper['authors']}\n"
            response += f"**Journal:** {paper['journal']} ({paper['year']})\n"

            if paper.get('species'):
                response += f"**Species:** {paper['species']}\n"
            if paper.get('locations'):
                response += f"**Locations:** {paper['locations']}\n"
            if paper.get('methods'):
                response += f"**Methods:** {paper['methods']}\n"

            response += f"**DOI:** {paper['doi']}\n\n"

        def collect(field):
            # Insertion-ordered dedupe across papers, in relevance order.
            seen = {}
            for paper in papers:
                for item in (paper.get(field) or '').split(','):
                    if item.strip():
                        seen[item.strip()] = None
            return list(seen)

        all_species = collect('species')
        all_locations = collect('locations')

        response += "### Summary\n"
        if all_species:
            response += f"**Key Species:** {', '.join(all_species[:5])}\n"
        if all_locations:
            response += f"**Study Regions:** {', '.join(all_locations)}\n"

        return response

    def generate_response(self, query, search_results):
        """Turn search results into an answer string (OpenAI or template)."""
        if not search_results['documents'][0]:
            return "No relevant papers found for your query."

        papers = search_results['metadatas'][0]

        if getattr(self, 'use_openai', False):
            return self._generate_openai_response(query, papers, search_results)
        return self._generate_template_response(query, papers, search_results)

    def query(self, question, n_results=5):
        """Search and answer ``question``.

        Returns a dict with ``question``, ``response``, ``papers_found`` and
        the raw ``search_results``.
        """
        search_results = self.search(question, n_results)
        response = self.generate_response(question, search_results)

        return {
            'question': question,
            'response': response,
            'papers_found': len(search_results['documents'][0]),
            'search_results': search_results
        }

    def get_status(self):
        """Describe the active components and how many papers are indexed."""
        return {
            'vector_db': 'ChromaDB' if getattr(self, 'use_chromadb', False) else 'Simple Store',
            'embeddings': 'Transformer' if getattr(self, 'use_transformers', False) else 'TF-IDF',
            'generation': 'OpenAI GPT' if getattr(self, 'use_openai', False) else 'Template',
            'papers_loaded': len(self.papers) if self.papers else 0
        }

In [7]:
def get_sample_iolr_papers():
    """Return a fresh list of sample IOLR publications as paper dicts.

    Each dict has the keys ``title``, ``authors``, ``journal``, ``year``,
    ``doi`` and ``abstract``. Note: ``abstract`` here is only the title
    followed by the citation string, not the real paper abstract.
    """
    papers = [
        {
            'title': 'BTEX and PAH contributions to Lake Kinneret water: a seasonal-based study of volatile and semi-volatile anthropogenic pollutants in freshwater sources',
            'authors': 'Astrahan, P., Lupu, A., Leibovici, E., and S. Ninio.',
            'journal': 'Environmental Science and Pollution Research',
            'year': 2023,
            'doi': 'https://doi.org/10.1007/s11356-023-26724-9',
            'abstract': 'BTEX and PAH contributions to Lake Kinneret water: a seasonal-based study of volatile and semi-volatile anthropogenic pollutants in freshwater sources. Environmental Science and Pollution Research, 30(21):61145-61159',
        },
        {
            'title': 'Sodium levels and grazing pressure shape natural communities of the intracellular pathogen Legionella',
            'authors': 'Bergman, O., Beeri-Shlevin, Y., and S. Ninio.',
            'journal': 'Microbiome',
            'year': 2023,
            'doi': 'https://doi.org/10.1186/s40168-023-01611-0',
            'abstract': 'Sodium levels and grazing pressure shape natural communities of the intracellular pathogen Legionella. Microbiome, 11:167',
        },
        {
            'title': 'Anthropogenic and natural disturbances along a river and its estuary alter the diversity of pathogens and antibiotic resistance mechanisms',
            'authors': 'Rubin-Blum M., Harbuzov Z., Cohen R., and P. Astrahan.',
            'journal': 'Science of The Total Environment',
            'year': 2023,
            'doi': 'https://doi.org/10.1016/j.scitotenv.2023.164108',
            'abstract': 'Anthropogenic and natural disturbances along a river and its estuary alter the diversity of pathogens and antibiotic resistance mechanisms. Science of The Total Environment, 887:164108',
        },
        {
            'title': 'The microbial community spatially varies during a Microcystis bloom event in Lake Kinneret',
            'authors': 'Schweitzer-Natan, O., Ofek-Lalzar, M., Sher, D., and A. Sukenik.',
            'journal': 'Freshwater Biology',
            'year': 2023,
            'doi': 'https://doi.org/10.1111/fwb.14030',
            'abstract': 'The microbial community spatially varies during a Microcystis bloom event in Lake Kinneret. Freshwater Biology, 68(2):349-363',
        },
        {
            'title': 'Upstream nitrogen availability determines the Microcystis salt tolerance and influences microcystins release in brackish water',
            'authors': 'Li X., Li L., Huang Y., Wu H., Sheng S., Jiang X., Chen X., and I. Ostrovsky.',
            'journal': 'Water Research',
            'year': 2024,
            'doi': 'https://doi.org/10.1016/j.watres.2024.121213',
            'abstract': 'Upstream nitrogen availability determines the Microcystis salt tolerance and influences microcystins release in brackish water. Water Research, 252:121213',
        },
    ]
    return papers

In [8]:
def create_premium_interface(rag_system):
    """Build the dark-themed Gradio Blocks UI around ``rag_system``.

    Args:
        rag_system: Object exposing ``query(question, n_results=...)`` (returning
            a dict with a ``'response'`` key) and ``get_status()``.

    Returns:
        The (not yet launched) ``gr.Blocks`` interface.
    """

    def query_papers(question, n_results):
        # Gradio may pass None for an empty textbox; treat it like blank input.
        if not question or not question.strip():
            return "Please enter a research question to get started."
        result = rag_system.query(question, n_results=int(n_results))
        return f"### 🤖 AI Answer\n{result['response']}"

    def get_system_metrics():
        status = rag_system.get_status()
        # Flush-left bullet list: consecutive plain lines would merge into one
        # Markdown paragraph, and indented lines can render as a code block.
        return "\n".join([
            "**System Configuration**",
            "",
            f"- Vector Database: {status['vector_db']} 🗄️",
            f"- Embeddings: {status['embeddings']} 🔮",
            f"- Generation: {status['generation']} 🤖",
            f"- Papers Loaded: {status['papers_loaded']} 📚",
        ])

    # Curated example questions: (question, number of papers to analyze)
    examples = [
        ("How do BTEX and PAH pollutants contaminate freshwater sources?", 3),
        ("What seasonal patterns exist for volatile pollutants in Lake Kinneret?", 3),
        ("How do anthropogenic disturbances affect pathogen diversity in rivers?", 4),
        ("What antibiotic resistance mechanisms emerge from environmental pollution?", 3),
        ("How do semi-volatile pollutants impact aquatic ecosystem health?", 3),
    ]

    def run_example(idx: int):
        # Fill the inputs and answer through the same path as a manual query,
        # so example answers are formatted identically.
        q, n = examples[idx]
        return q, int(n), query_papers(q, n)

    # Custom CSS with fix for prompt textbox
    custom_css = """
    :root {
        --primary-color: #2563eb;
        --secondary-color: #1e40af;
        --background-color: #0f172a;
        --surface-color: #1e293b;
        --surface-light: #334155;
        --text-primary: #f8fafc;
        --text-secondary: #cbd5e1;
        --accent-color: #3b82f6;
        --success-color: #059669;
        --border-color: #475569;
    }
    .gradio-container { background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%); color: var(--text-primary); font-family: 'Inter','Segoe UI',sans-serif; min-height: 100vh; }
    .main-header { background: rgba(30,41,59,.8); backdrop-filter: blur(20px); border-radius: 16px; padding: 2rem; margin: 1rem; box-shadow: 0 8px 32px rgba(0,0,0,.3); border: 1px solid rgba(71,85,105,.3); }
    .query-section { background: rgba(30,41,59,.6); backdrop-filter: blur(15px); border-radius: 12px; padding: 1.5rem; box-shadow: 0 4px 20px rgba(0,0,0,.2); border: 1px solid rgba(71,85,105,.2); }
    .response-area { background: rgba(30,41,59,.7) !important; border-radius: 12px !important; border: 1px solid rgba(71,85,105,.3) !important; box-shadow: 0 4px 20px rgba(0,0,0,.15) !important; color: var(--text-primary) !important; }
    .status-panel { background: linear-gradient(135deg,#1e40af 0%,#3730a3 100%); border-radius: 12px; padding: 1rem; color: white; box-shadow: 0 4px 15px rgba(59,130,246,.2); border: 1px solid rgba(147,197,253,.2); }

    .gradio-textbox textarea, .gradio-textbox input { background: rgba(51,65,85,.8) !important; border: 1px solid rgba(71,85,105,.4) !important; border-radius: 8px !important; color: var(--text-primary) !important; font-size: 14px !important; }
    .gradio-textbox textarea:focus, .gradio-textbox input:focus { border-color: var(--accent-color) !important; box-shadow: 0 0 0 2px rgba(59,130,246,.2) !important; }
    .gradio-button { background: linear-gradient(135deg,var(--primary-color) 0%,var(--secondary-color) 100%) !important; border: none !important; border-radius: 8px !important; color: white !important; font-weight: 600 !important; transition: all .3s ease !important; }
    .gradio-button:hover { transform: translateY(-1px) !important; box-shadow: 0 6px 20px rgba(37,99,235,.3) !important; }
    .gradio-slider input[type="range"] { background: rgba(51,65,85,.8) !important; }
    .gradio-accordion { background: rgba(30,41,59,.5) !important; border: 1px solid rgba(71,85,105,.3) !important; border-radius: 8px !important; }

    .examples-grid { display: grid; grid-template-columns: 1fr; gap: .5rem; }
    @media (min-width: 900px) {
      .examples-grid { grid-template-columns: 1fr 1fr; }
    }
    .example-btn { text-align: left; white-space: normal; line-height: 1.3; padding: .75rem 1rem; }
    .example-meta { opacity: .85; font-size: .9rem; }
    .markdown-content h1, .markdown-content h2, .markdown-content h3 { color: var(--text-primary) !important; }
    .markdown-content p { color: var(--text-secondary) !important; line-height: 1.6 !important; }
    .gradio-label { color: var(--text-primary) !important; font-weight: 500 !important; }

    /* Force black text ONLY in the prompt textbox */
    #prompt_box textarea,
    #prompt_box input[type="text"] {
        color: black !important;
        caret-color: black !important;
    }
    #prompt_box textarea::placeholder,
    #prompt_box input[type="text"]::placeholder {
        color: #111 !important;
        opacity: .6 !important;
    }
    """

    with gr.Blocks(
        title="Ecological Research Assistant",
        theme=gr.themes.Base(
            primary_hue="blue",
            secondary_hue="slate",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter")
        ).set(
            body_background_fill="#0f172a",
            body_text_color="#f8fafc",
            background_fill_primary="#1e293b",
            background_fill_secondary="#334155",
            border_color_primary="#475569",
            color_accent="#3b82f6",
            color_accent_soft="#1e40af"
        ),
        css=custom_css
    ) as interface:

        with gr.Column(elem_classes="main-header"):
            gr.HTML("""
            <div style="text-align: center;">
                <h1 style="background: linear-gradient(45deg,#60a5fa,#3b82f6);
                           -webkit-background-clip: text;
                           -webkit-text-fill-color: transparent;
                           font-size: 2.5rem; margin: 0; font-weight: 700;">
                    Ecological Research Assistant 🌊
                </h1>
                <p style="font-size: 1.1rem; color: #94a3b8; margin-top: .5rem; font-weight: 400;">
                    AI-powered insights from IOLR marine and freshwater research
                </p>
            </div>
            """)

        with gr.Row():
            with gr.Column(scale=3, elem_classes="query-section"):
                question_input = gr.Textbox(
                    label="🔍 Research Question",
                    placeholder="Ask about subjects in pollution that interest you",
                    lines=3,
                    elem_id="prompt_box",  # targeted by the #prompt_box CSS rules
                )
                with gr.Row():
                    n_results_slider = gr.Slider(
                        minimum=1, maximum=5, value=5, step=1,
                        label="Papers to analyze 📄 ",
                    )
                    submit_btn = gr.Button("Analyze Research 🚀 ", variant="primary")

            with gr.Column(scale=1, elem_classes="status-panel"):
                system_info = gr.Markdown(get_system_metrics())

        response_output = gr.Markdown(label="📊 Research Analysis", elem_classes="response-area", show_label=True)

        with gr.Accordion("💡 Example Questions", open=False):
            gr.HTML('<div class="example-meta">Try these research questions:</div>')
            with gr.Column(elem_classes="examples-grid"):
                for i, (q, n) in enumerate(examples):
                    btn = gr.Button(f"🔎 {q}  ·  {n} papers", elem_classes="example-btn")
                    # Bind the loop index as a default argument; a plain closure
                    # would only see the final value of `i`.
                    btn.click(
                        fn=lambda idx=i: run_example(idx),
                        inputs=[],
                        outputs=[question_input, n_results_slider, response_output],
                    )

        submit_btn.click(fn=query_papers, inputs=[question_input, n_results_slider], outputs=response_output)
        question_input.submit(fn=query_papers, inputs=[question_input, n_results_slider], outputs=response_output)

    return interface


In [9]:
def initialize_system(openai_api_key=None, papers=None):
    """Create an EcologicalRAG instance and index papers into it.

    Args:
        openai_api_key: OpenAI key to use. When None, the ``OPENAI_API_KEY``
            environment variable is used; if that is unset as well, answers
            come from the template generator instead of the OpenAI API.
        papers: Paper dicts to index; defaults to ``get_sample_iolr_papers()``.

    Returns:
        The initialized system. If no paper could be indexed, it has an
        empty store and queries return "no relevant papers".
    """
    import os

    # Never hardcode credentials. The previous placeholder string enabled
    # OpenAI mode with an invalid key, so every query made a doomed API call.
    if openai_api_key is None:
        openai_api_key = os.environ.get("OPENAI_API_KEY")

    rag_system = EcologicalRAG(openai_api_key=openai_api_key)
    rag_system.load_papers(papers if papers is not None else get_sample_iolr_papers())
    return rag_system

In [10]:
def find_available_port(start_port=7860, max_attempts=10):
    """Return the first port in ``[start_port, start_port + max_attempts)`` bindable on localhost.

    Returns None when every candidate is taken. The probe socket is closed
    before returning, so another process could still grab the port before
    the caller binds it.
    """
    import socket

    def can_bind(port):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('localhost', port))
        except OSError:
            return False
        return True

    candidates = range(start_port, start_port + max_attempts)
    return next((port for port in candidates if can_bind(port)), None)

In [11]:
def launch_app(share=True):
    """Initialize the RAG system and, if Gradio is installed, launch the web UI.

    Args:
        share: Passed to ``interface.launch``. True (the default) creates a
            public gradio.live link, so anyone with the URL can query the app.

    Returns:
        The initialized RAG system in both cases (previously it was returned
        only when Gradio was unavailable), so it can also be used programmatically.
    """
    rag_system = initialize_system()
    if not GRADIO_AVAILABLE:
        return rag_system

    interface = create_premium_interface(rag_system)
    # find_available_port returns None when no candidate is free; passing
    # server_port=None is equivalent to omitting it (Gradio picks its default).
    interface.launch(share=share, server_port=find_available_port(), show_error=False, quiet=True)
    return rag_system

In [12]:
# Run the app. Inside a Jupyter kernel `__name__` is also "__main__", so running
# this cell launches Gradio immediately (comment out the call to skip launching).
if __name__ == "__main__":
    launch_app()

* Running on public URL: https://0c1e40d1159e350e1a.gradio.live
