In [None]:
!pip install -q transformers accelerate bitsandbytes gradio torch pillow

import re
from typing import List, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

class RiverPollutionAnalyzerInstructBLIP:
    """Detect river pollutants in an image and answer follow-up questions.

    Wraps the ``Salesforce/instructblip-vicuna-7b`` checkpoint loaded with 4-bit
    bitsandbytes quantization. ``analyze_image`` returns a Markdown report and
    stores the image, detected pollutants and severity so that
    ``chat_response`` can answer questions about the most recently analyzed
    image.
    """

    # Relative danger of each known pollutant. Used only by the heuristic
    # severity fallback (``_calculate_severity``) when the model does not
    # return a usable severity number. Unknown names weigh 1.
    _POLLUTANT_WEIGHTS = {
        "medical waste": 3,
        "toxic sludge": 3,
        "oil spill": 2.5,
        "chemical foam": 2,
        "industrial discharge": 2,
        "sewage water": 2,
        "plastic waste": 1.5,
        "construction waste": 1.5,
        "floating trash": 1,
        "organic debris": 1,
        "algal bloom": 1.5,
        "agricultural runoff": 1.5,
    }

    def __init__(self):
        """Load the processor and the 4-bit quantized model; reset per-image state."""
        model_name = "Salesforce/instructblip-vicuna-7b"
        self.processor = InstructBlipProcessor.from_pretrained(model_name)
        # Passing ``load_in_4bit`` straight to ``from_pretrained`` is deprecated
        # in transformers; the supported route is a ``BitsAndBytesConfig``.
        # Compute in float16 to match the float16 inputs built in ``_generate``.
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
        )

        # Pollutant names the parser recognizes. Keep them lower-case: the
        # model output is lower-cased before matching.
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff"
        ]

        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe"
        }

        self.current_image = None
        self.current_pollutants = []  # (name, "N/A") tuples, at most 5
        self.severity = None  # int in 1..10 once an image has been analyzed
        self.chat_history = []  # (question, answer) pairs for the chatbot

    def analyze_image(self, image):
        """Analyze river pollution in *image* and return a Markdown report.

        Args:
            image: A PIL image, or anything ``Image.fromarray`` accepts (e.g. a
                NumPy array). ``None`` (no upload) returns a short notice and
                leaves the analyzer state untouched.

        Returns:
            The Markdown report built by ``_format_analysis``.

        Side effects:
            On success, replaces ``current_image``, ``current_pollutants`` and
            ``severity`` and clears ``chat_history`` so later chat turns refer
            to this image only.
        """
        if image is None:
            return "Please upload an image first!"
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Request a fixed two-line answer format so it can be parsed reliably.
        prompt = """Carefully analyze this river pollution scene and provide:
1. List ALL visible pollutants from this list: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
2. Estimate pollution severity from 1-10 based on: pollutant types, coverage area, and water discoloration

Format your response EXACTLY like this:
Pollutants: [comma separated list]
Severity: [number between 1-10]"""

        analysis = self._generate(
            image,
            prompt,
            max_new_tokens=200,
            temperature=0.5,  # lower temperature for more factual responses
            top_p=0.85,
            do_sample=True,
        )
        pollutants, severity = self._parse_analysis(analysis)

        # Update all per-image state together, after generation succeeded.
        # current_pollutants is always reassigned: previously it kept the
        # *previous* image's pollutants whenever "Pollutants:" was missing.
        self.current_image = image
        self.chat_history = []
        self.current_pollutants = [(p, "N/A") for p in pollutants[:5]]
        self.severity = severity

        return self._format_analysis(pollutants)

    def _generate(self, image, prompt: str, **generate_kwargs) -> str:
        """Run InstructBLIP on ``(image, prompt)`` and return the first decoded sequence.

        ``generate_kwargs`` are forwarded unchanged to ``self.model.generate``.
        """
        # Use the model's own device instead of hard-coding "cuda".
        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
        ).to(self.model.device, torch.float16)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generate_kwargs)

        return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]

    def _parse_analysis(self, analysis: str) -> Tuple[List[str], int]:
        """Extract ``(pollutants, severity)`` from the model's text answer.

        Matching is case-insensitive. Pollutants are the entries of
        ``self.pollutants`` found in the ``Pollutants:`` section, in the order
        the model listed them and without duplicates. Severity is the number
        after ``Severity:``, clamped to 1-10. If no number is present it falls
        back to ``_calculate_severity(pollutants)``.
        """
        text = analysis.lower()

        pollutants: List[str] = []
        if "pollutants:" in text:
            # Use the LAST section: if the decoded output echoes the prompt,
            # the first "Pollutants:" is the format template, not the answer.
            section = text.rsplit("pollutants:", 1)[1].split("severity:", 1)[0]
            # Substring search tolerates brackets, trailing periods, "and",
            # etc., which made exact comma-split matching drop pollutants.
            pollutants = [p for p in self.pollutants if p in section]
            pollutants.sort(key=section.find)

        # Last match for the same echo reason; "\[?" also accepts "Severity: [7]".
        numbers = re.findall(r"severity:\s*\[?\s*(\d+)", text)
        if numbers:
            severity = max(1, min(10, int(numbers[-1])))
        else:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Heuristic 1-10 severity from pollutant *types* (fallback only).

        Averages the per-type weights in ``_POLLUTANT_WEIGHTS`` (unknown names
        weigh 1) and maps the average onto 2/4/6/8/10. An empty list gives 1.
        Because the weights are averaged, the number of pollutants does not
        raise the score; quantity and coverage are not considered.
        """
        if not pollutants:
            return 1  # nothing detected -> minimal severity

        weights = self._POLLUTANT_WEIGHTS
        avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)

        # (exclusive upper bound of the average weight, severity for that band)
        for upper, level in ((1.2, 2), (1.6, 4), (2.0, 6), (2.5, 8)):
            if avg_weight < upper:
                return level
        return 10

    def _format_analysis(self, pollutants: List[str]) -> str:
        """Render the Markdown report returned by ``analyze_image``.

        Shows ``self.severity`` as a 10-cell text bar with its description,
        followed by at most the first five *pollutants*, capitalized.
        """
        severity_bar = ""
        if self.severity is not None:
            filled = "█" * self.severity
            empty = "░" * (10 - self.severity)
            severity_bar = f"""
📊 **Severity Level**: {self.severity}/10
{filled}{empty}
{self.severity_descriptions.get(self.severity, '')}
"""

        if pollutants:
            pollutants_list = "🔍 **Detected Pollutants**:\n" + "".join(
                f"{idx}. {pollutant.capitalize()}\n"
                for idx, pollutant in enumerate(pollutants[:5], 1)
            )
        else:
            pollutants_list = "🔍 No significant pollutants detected\n"

        return f"""
🌊 **River Pollution Analysis** 🌊

{severity_bar}
{pollutants_list}
"""

    def chat_response(self, message):
        """Answer a follow-up question about the most recently analyzed image.

        Args:
            message: The user's question.

        Returns:
            ``(textbox_value, chat_history)``. Before any analysis, the notice
            "Please analyze an image first!" is returned as the textbox value.
            A blank question is ignored and only clears the textbox.
            Otherwise the textbox is cleared and the new exchange is appended
            to ``chat_history``.
        """
        if self.current_image is None:
            return "Please analyze an image first!", self.chat_history
        if not message or not message.strip():
            return "", self.chat_history

        pollutant_names = (
            ", ".join(name for name, _ in self.current_pollutants[:3])
            or "none detected"
        )
        context = f"""Current analysis:
- Pollutants: {pollutant_names}
- Severity: {self.severity}/10 ({self.severity_descriptions.get(self.severity, '')})
"""

        if self.chat_history:
            # Only the last two exchanges, to keep the prompt short.
            context += "\nPrevious conversation:\n" + "\n".join(
                f"User: {q}\nAssistant: {a}" for q, a in self.chat_history[-2:]
            )

        prompt = f"""{context}
New question: {message}
Answer specifically about the pollution, considering:
1. Visual evidence in the image
2. Environmental impact
3. Possible solutions
Assistant:"""

        response = self._generate(
            self.current_image,
            prompt,
            max_new_tokens=250,
            temperature=0.7,
            repetition_penalty=1.2,
            # Without do_sample=True, generate() decodes greedily and the
            # temperature setting is ignored.
            do_sample=True,
        )

        # If the decoded text echoes the prompt, keep only the final answer.
        response = response.rsplit("Assistant:", 1)[-1].strip()

        self.chat_history.append((message, response))
        return "", self.chat_history

# Build the single shared analyzer up front; every Gradio callback below uses
# this instance. Print a readable message on failure, then re-raise it.
try:
    analyzer = RiverPollutionAnalyzerInstructBLIP()
except Exception as init_error:
    print(f"Error initializing analyzer: {init_error}")
    raise

# Custom CSS for the Gradio interface, passed to gr.Blocks(css=...).
# `.analysis-box` and `.chatbot` style components created with the matching
# `elem_classes`; `.dark .analysis-box` restyles the box under a `.dark`
# ancestor (presumably Gradio's dark-mode class -- confirm); the `footer` rule
# hides the page footer.
# NOTE(review): `.severity-bar` is not used by any `elem_classes` in this file;
# confirm whether it is still needed.
css = """
.gradio-container {
    max-width: 900px !important;
}
.analysis-box {
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    background: #f9f9f9;
}
.chatbot {
    min-height: 500px;
}
.dark .analysis-box {
    background: #2a2a2a;
    border-color: #444;
}
footer {
    display: none !important;
}
.severity-bar {
    font-family: monospace;
    font-size: 16px;
    line-height: 1.5;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("""
    # 🌍 River Pollution Analyzer (InstructBLIP)
    *Upload an image of a river to analyze pollution levels using InstructBLIP model*
    """)

    with gr.Row():
        # Left column: image upload, analyze button and the Markdown report.
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                image_input = gr.Image(
                    type="pil",
                    label="Upload River Image",
                    elem_classes="image-upload"
                )
                analyze_btn = gr.Button(
                    "Analyze Pollution",
                    variant="primary",
                    size="lg"
                )

            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### 🔬 Analysis Results")
                analysis_output = gr.Markdown(
                    label="",
                    elem_classes="analysis-results"
                )

        # Right column: chat about the analyzed image.
        with gr.Column(scale=2, min_width=500):
            chatbot = gr.Chatbot(
                label="Pollution Analysis Chat",
                bubble_full_width=False,
                height=500,
                elem_classes="chatbot"
            )

            with gr.Row():
                chat_input = gr.Textbox(
                    placeholder="Ask about pollution sources, environmental impact, or cleanup recommendations...",
                    label="Your Question",
                    container=False,
                    scale=5
                )
                chat_btn = gr.Button(
                    "Ask",
                    variant="secondary",
                    scale=1,
                    min_width=100
                )

            with gr.Row():
                clear_btn = gr.Button("Clear Chat History", size="sm")
                gr.Markdown("*Tip: Ask specific questions for detailed answers*", elem_classes="tip-text")

    def _clear_chat():
        """Forget the analyzer's conversation and empty the textbox and chatbot."""
        # Clearing only the widget left analyzer.chat_history intact, so
        # chat_response kept feeding the "cleared" turns into its prompt.
        analyzer.chat_history = []
        return "", []

    # Analysis action. analyze_image resets analyzer.chat_history for a new
    # image, so mirror that in the chatbot instead of showing the old chat.
    analyze_btn.click(
        analyzer.analyze_image,
        inputs=image_input,
        outputs=analysis_output,
    ).then(
        lambda: analyzer.chat_history,
        outputs=chatbot,
    )

    # Chat actions: Enter in the textbox and the "Ask" button behave the same.
    chat_input.submit(
        analyzer.chat_response,
        inputs=chat_input,
        outputs=[chat_input, chatbot],
    )
    chat_btn.click(
        analyzer.chat_response,
        inputs=chat_input,
        outputs=[chat_input, chatbot],
    )

    # Clear chat: both the display and the analyzer's memory.
    clear_btn.click(_clear_chat, outputs=[chat_input, chatbot])

# Start the Gradio app with a public share link and debug mode enabled.
demo.launch(share=True, debug=True)