In [2]:
# Gemma 3N E2B Multimodal System - Text, Images, Audio
# pip install transformers torch accelerate pillow

from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch
import json
import re
from typing import Dict, List, Tuple, Optional
from pathlib import Path
def do_gemma_3n_inference(model, messages, max_new_tokens=256, temperature=0.7):
    """
    Inference function for Gemma 3N with multimodal support
    """
    try:
        # Get the processor from the model
        processor = model.processor if hasattr(model, 'processor') else None
        
        # If no processor available, use a simple text response
        if not processor:
            print("⚠️ No processor found, returning placeholder response")
            return "I'm a placeholder response since the model processor isn't available yet."
        
        # Process the messages
        # Extract text and images from messages
        text_parts = []
        images = []
        
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content", [])
                for item in content:
                    if item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                    elif item.get("type") == "image":
                        # Load image if it's a path
                        img_path = item.get("image")
                        if img_path and Path(img_path).exists():
                            try:
                                image = Image.open(img_path)
                                images.append(image)
                            except Exception as e:
                                print(f"⚠️ Could not load image {img_path}: {e}")
        
        # Combine text parts
        prompt = " ".join(text_parts)
        
        # Process inputs
        if images:
            # Process with images
            inputs = processor(
                text=prompt,
                images=images,
                return_tensors="pt"
            )
        else:
            # Text only
            inputs = processor(
                text=prompt,
                return_tensors="pt"
            )
        
        # Move inputs to the same device as model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}
        
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=processor.tokenizer.eos_token_id
            )
        
        # Decode response
        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        
        # Remove the original prompt from response if it's included
        if prompt in response:
            response = response.replace(prompt, "").strip()
        
        return response
        
    except Exception as e:
        print(f"❌ Inference error: {e}")
        return f"Sorry, I encountered an error: {str(e)}"

class Gemma3NEducationSystem:
    def __init__(self, base_path: str = "../datasets/education"):
        """Initialize Gemma 3N multimodal education system"""
        self.base_path = Path(base_path)
        
        # Load structured data
        self.structured_data = self._load_json("education_structured_data_extract.json")
        
        # Load model
        self.model = None
        self.processor = None
        self._load_model()
        
        print(f"✅ System ready! Loaded {len(self.structured_data)} sections")
    
    def _load_json(self, filename: str):
        """Load JSON file"""
        try:
            file_path = self.base_path / filename
            if file_path.exists():
                with open(file_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            return []
        except Exception as e:
            print(f"Error loading {filename}: {e}")
            return []
    
    def _load_model(self):
        """Load Gemma 3N E2B model"""
        print("🔄 Loading Gemma 3N E2B...")
        
        try:
            model_id = "google/gemma-3n-E2B-it"
            
            self.processor = AutoProcessor.from_pretrained(model_id)
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            
            # Use MPS if available (Mac)
            if torch.backends.mps.is_available():
                self.model = self.model.to("mps")
                print("✅ Using Apple Silicon MPS")
            
            print("✅ Gemma 3N E2B loaded!")
            
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
    
    def is_page_based_question(self, question: str) -> Tuple[bool, Optional[int]]:
        """Check if question mentions a specific page"""
        page_patterns = [
            r'(?:in|on|from|at)\s+page\s+(\d+)',
            r'page\s+(\d+)',
            r'p\.?\s*(\d+)',
            r'pg\.?\s*(\d+)'
        ]
        
        question_lower = question.lower()
        for pattern in page_patterns:
            match = re.search(pattern, question_lower)
            if match:
                return True, int(match.group(1))
        return False, None
    
    def get_page_info(self, page_number: int) -> Dict:
        """Get content and images for a specific page"""
        matching_sections = []
        all_images = []
        
        for section in self.structured_data:
            if not isinstance(section, dict):
                continue
            
            page_start = section.get('page_start', 0)
            page_end = section.get('page_end', 0)
            
            if page_start <= page_number <= page_end:
                matching_sections.append(section)
                
                # Get images for this page
                section_images = section.get('images', [])
                for img in section_images:
                    if isinstance(img, dict):
                        img_page = img.get('page', 0)
                        img_path = img.get('path', '')
                        if img_page == page_number and img_path:
                            all_images.append(img_path)
        
        if matching_sections:
            combined_text = "\n\n".join([
                f"{section.get('main_heading', '')} - {section.get('sub_heading', '')}\n{section.get('content', '')}"
                for section in matching_sections
            ])
            
            return {
                'text': combined_text,
                'images': all_images[:3],  # Max 3 images
                'sections': matching_sections
            }
        
        return {'text': f"No content found for page {page_number}", 'images': [], 'sections': []}
    
    def process_text_question(self, question: str) -> str:
        """Process text-only question"""
        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": f"Question: {question}\n\nAnswer:"}
            ]
        }]
        
        return do_gemma_3n_inference(self.model, messages, max_new_tokens=300)
    
    def process_image_question(self, question: str, image_path: str) -> str:
        """Process question with image"""
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": f"Question: {question}"}
            ]
        }]
        
        return do_gemma_3n_inference(self.model, messages, max_new_tokens=300)
    
    def process_audio_question(self, audio_file: str, question: str = "What is this audio about?") -> str:
        """Process audio question"""
        messages = [{
            "role": "user",  
            "content": [
                {"type": "audio", "audio": audio_file},
                {"type": "text", "text": question}
            ]
        }]
        
        return do_gemma_3n_inference(self.model, messages, max_new_tokens=300)
    
    def process_multimodal_question(self, question: str, image_paths: List[str] = None, audio_file: str = None) -> str:
        """Process question with multiple modalities"""
        content = []
        
        # Add audio if provided
        if audio_file:
            content.append({"type": "audio", "audio": audio_file})
        
        # Add images if provided
        if image_paths:
            for img_path in image_paths:
                content.append({"type": "image", "image": img_path})
        
        # Add text question
        content.append({"type": "text", "text": question})
        
        messages = [{"role": "user", "content": content}]
        
        return do_gemma_3n_inference(self.model, messages, max_new_tokens=400)
    
    def answer_general_question(self, question: str, audio_file: str = None) -> Dict:
        """Answer general question (with optional audio)"""
        if audio_file:
            answer = self.process_audio_question(audio_file, question)
        else:
            answer = self.process_text_question(question)
        
        return {
            'answer': answer,
            'type': 'general',
            'modalities': ['audio', 'text'] if audio_file else ['text']
        }
    
    def answer_page_based_question(self, question: str, page_number: int, audio_file: str = None) -> Dict:
        """Answer page-based question with context, images, and optional audio"""
        # Get page info
        page_info = self.get_page_info(page_number)
        
        if not page_info['text'] or 'No content found' in page_info['text']:
            return {
                'answer': f"No content found for page {page_number}",
                'type': 'page_based',
                'page_number': page_number
            }
        
        # Build context with page content
        context_question = f"Context: {page_info['text']}\n\nQuestion: {question}"
        
        # Process with multiple modalities
        answer = self.process_multimodal_question(
            context_question,
            image_paths=page_info['images'],
            audio_file=audio_file
        )
        
        modalities = ['text']
        if page_info['images']:
            modalities.append('images')
        if audio_file:
            modalities.append('audio')
        
        return {
            'answer': answer,
            'type': 'page_based',
            'page_number': page_number,
            'modalities': modalities,
            'images_used': len(page_info['images'])
        }
    
    def process_question(self, question: str, audio_file: str = None) -> Dict:
        """Main processing method"""
        print(f"🎯 Processing: {question}")
        if audio_file:
            print(f"🎤 With audio: {audio_file}")
        
        # Check question type
        is_page_based, page_number = self.is_page_based_question(question)
        
        if is_page_based and page_number:
            return self.answer_page_based_question(question, page_number, audio_file)
        else:
            return self.answer_general_question(question, audio_file)

# Test function
def test_all_modalities():
    """Test text, image, and audio"""
    system = Gemma3NEducationSystem("../datasets/education")
    
    # Test text only
    print("🤖 Testing text...")
    result1 = system.process_question("What is irrigation?")
    print(f"Answer: {result1['answer'][:100]}...")
    
    # Test page-based with images
    print("\n📖 Testing page with images...")
    result2 = system.process_question("In page 18, explain about sowing")
    print(f"Answer: {result2['answer'][:100]}...")
    print(f"Modalities: {result2.get('modalities', [])}")
    
    # Test with audio (if you have an audio file)
    # print("\n🎤 Testing with audio...")
    # result3 = system.process_question("Explain what you heard", audio_file="test_audio.mp3")
    # print(f"Answer: {result3['answer'][:100]}...")

if __name__ == "__main__":
    test_all_modalities()

🔄 Loading Gemma 3N E2B...
❌ Model loading failed: Unrecognized configuration class <class 'transformers.models.gemma3n.configuration_gemma3n.Gemma3nConfig'> for this kind of AutoModel: AutoModelForVision2Seq.
Model type should be one of BlipConfig, Blip2Config, ChameleonConfig, GitConfig, Idefics2Config, Idefics3Config, InstructBlipConfig, InstructBlipVideoConfig, Kosmos2Config, LlavaConfig, LlavaNextConfig, LlavaNextVideoConfig, LlavaOnevisionConfig, Mistral3Config, MllamaConfig, PaliGemmaConfig, Pix2StructConfig, Qwen2_5_VLConfig, Qwen2VLConfig, VideoLlavaConfig, VipLlavaConfig, VisionEncoderDecoderConfig.
✅ System ready! Loaded 0 sections
🤖 Testing text...
🎯 Processing: What is irrigation?
⚠️ No processor found, returning placeholder response
Answer: I'm a placeholder response since the model processor isn't available yet....

📖 Testing page with images...
🎯 Processing: In page 18, explain about sowing
Answer: No content found for page 18...
Modalities: []
