In [29]:
import json
import os
import re
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional

import requests

class RiceDiseaseQAGenerator:
    """Generate rice-disease Q&A training pairs with a locally served Ollama model.

    Q&A pairs are produced at three levels of granularity:

    * ``"record"``       - 5 pairs per individual input record;
    * ``"main_heading"`` - 10 pairs per (disease, main heading) group;
    * ``"disease"``      - 20 pairs per disease, combining all of its records.

    Input records are dicts. The keys read are ``disease``, ``causal_organism``,
    ``main_heading``, ``sub_heading``, ``content``, ``source`` and ``pages``.
    A missing key and an explicit ``None`` (JSON ``null``) are treated alike.
    """

    # Q/A label patterns for ``parse_qa_response``. Besides the plain "Q1:" /
    # "A1:" requested in the prompts, they accept the markdown bold that chat
    # models often add ("**Q1:**", "**Q1: question?**") and a space before the
    # colon. The \b keeps words such as "FAQ2:" inside an answer from being
    # treated as a new question label.
    _QUESTION_LABEL = re.compile(r"\**\s*\bQ\d+\s*:\s*\**")
    _ANSWER_LABEL = re.compile(r"\**\s*\bA\d+\s*:\s*\**")

    def __init__(self, ollama_url="http://localhost:11434", model_name="gemma3n:e4b"):
        """Store connection settings.

        Args:
            ollama_url: Base URL of the Ollama server; ``/api/generate`` is
                appended to it for every request.
            model_name: Ollama model tag sent with every request.
        """
        self.ollama_url = ollama_url
        self.model_name = model_name
        # Not written by any method of this class; kept as a public attribute
        # for backward compatibility.
        self.generated_qa = []

    @staticmethod
    def _field(record: Dict, key: str, default: Any = "") -> Any:
        """Return ``record[key]``, or *default* when the key is absent or ``None``."""
        value = record.get(key)
        return default if value is None else value

    @staticmethod
    def _clean_fragment(text: str) -> str:
        """Strip whitespace and a dangling bold marker left over by label splitting."""
        text = text.strip()
        # Splitting "**Q1: What causes blast?**" on its label leaves
        # "What causes blast?**". Drop an edge "**" only when the marker count
        # is odd, so balanced bold inside the text ("use **tricyclazole**")
        # survives.
        if text.count("**") % 2:
            if text.endswith("**"):
                text = text[:-2].rstrip()
            elif text.startswith("**"):
                text = text[2:].lstrip()
        return text

    def format_disease_data(self, record: Dict) -> tuple:
        """Extract the prompt fields from *record*.

        Returns:
            ``(disease, causal_organism, main_heading, sub_heading, content,
            source)``. Missing or ``None`` values become ``""``, except
            ``source``, which defaults to ``"Rice Diseases"``.
        """
        return (
            self._field(record, "disease"),
            self._field(record, "causal_organism"),
            self._field(record, "main_heading"),
            self._field(record, "sub_heading"),
            self._field(record, "content"),
            self._field(record, "source", "Rice Diseases"),
        )

    def generate_record_prompt(self, record: Dict) -> str:
        """Build the prompt asking for 5 Q&A pairs about a single record."""
        disease, causal_organism, main_heading, sub_heading, content, _source = self.format_disease_data(record)

        prompt = f"""You are creating educational Q&A pairs for Indian rice farmers learning about rice diseases. Based on this rice disease information, generate exactly 5 question-answer pairs that would help Indian farmers understand and manage this disease.

Each question must reference rice farming, rice disease management, or rice cultivation in Indian agriculture. Include the causal organism name where relevant.

Rice Disease Information:
Disease: {disease}
Causal Organism: {causal_organism}
Main Topic: {main_heading}
Sub-topic: {sub_heading}
Content: {content}

Generate 5 Q&A pairs in this format:
Q1: [question]
A1: [answer]

Q2: [question]
A2: [answer]

Q3: [question]
A3: [answer]

Q4: [question]
A4: [answer]

Q5: [question]
A5: [answer]

Focus on practical rice farming scenarios and advice for Indian rice farmers. Include organism names in questions when relevant."""

        return prompt

    def generate_main_heading_prompt(self, heading_records: List[Dict]) -> str:
        """Build the prompt asking for 10 Q&A pairs about one main heading.

        Disease, heading and organism are taken from the first record. Every
        record's content is appended as a blank-line-separated section,
        prefixed with its sub-heading when it has one.

        Returns:
            The prompt text, or ``""`` when *heading_records* is empty.
        """
        if not heading_records:
            return ""

        first = heading_records[0]
        disease = self._field(first, "disease")
        main_heading = self._field(first, "main_heading")
        causal_organism = self._field(first, "causal_organism")

        subtopics = []
        sections = []
        for record in heading_records:
            sub_heading = self._field(record, "sub_heading")
            content = self._field(record, "content")
            if sub_heading:
                subtopics.append(sub_heading)
                sections.append(f"\n\n{sub_heading}: {content}")
            else:
                sections.append(f"\n\n{content}")
        combined_content = "".join(sections)

        prompt = f"""You are creating comprehensive educational Q&A pairs for Indian rice farmers about rice disease management. Based on this complete information about {main_heading} of {disease} disease, generate exactly 10 question-answer pairs.

Focus on rice cultivation in Indian conditions and practical farming scenarios. Include the causal organism name where relevant.

Rice Disease Information:
Disease: {disease}
Causal Organism: {causal_organism}
Main Topic: {main_heading}
Sub-topics covered: {', '.join(subtopics) if subtopics else 'General'}
Content: {combined_content}

Generate 10 Q&A pairs in this format:
Q1: [question]
A1: [answer]

Q2: [question]
A2: [answer]

Continue for all 10 pairs. Cover comprehensive aspects of {main_heading} for rice disease management in Indian agriculture. Include organism names in questions when relevant."""

        return prompt

    def generate_disease_prompt(self, disease_records: List[Dict]) -> str:
        """Build the prompt asking for 20 Q&A pairs covering a whole disease.

        Records are grouped by main heading in first-seen order. Within each
        heading they keep their input order and are indented under it.

        Returns:
            The prompt text, or ``""`` when *disease_records* is empty.
        """
        if not disease_records:
            return ""

        disease = self._field(disease_records[0], "disease")
        causal_organism = self._field(disease_records[0], "causal_organism")

        # dict (and defaultdict) iteration follows insertion order, so the
        # headings come out in the order they first appear.
        headings_organized = defaultdict(list)
        for record in disease_records:
            headings_organized[self._field(record, "main_heading")].append(record)
        all_headings = list(headings_organized)

        parts = []
        all_subtopics = []
        for heading, heading_records in headings_organized.items():
            parts.append(f"\n\n{heading}:")
            for record in heading_records:
                sub_heading = self._field(record, "sub_heading")
                content = self._field(record, "content")
                if sub_heading:
                    all_subtopics.append(sub_heading)
                    parts.append(f"\n  {sub_heading}: {content}")
                else:
                    parts.append(f"\n  {content}")
        full_content = "".join(parts)

        prompt = f"""You are creating comprehensive educational Q&A pairs for Indian rice farmers and agricultural experts. Based on this complete information about {disease} disease covering all aspects, generate exactly 20 question-answer pairs for complete disease management.

Focus on comprehensive rice disease management in Indian rice cultivation. Include the causal organism name where relevant.

Rice Disease Information:
Disease: {disease}
Causal Organism: {causal_organism}
Main Topics covered: {', '.join(all_headings)}
Sub-topics covered: {', '.join(all_subtopics) if all_subtopics else 'General'}
Complete Content: {full_content}

Generate 20 Q&A pairs in this format:
Q1: [question]
A1: [answer]

Q2: [question]
A2: [answer]

Continue for all 20 pairs. Cover complete {disease} disease management including identification, prevention, treatment, impact, resistance, monitoring, and integrated management for Indian rice farming. Include organism names in questions when relevant."""

        return prompt

    def query_ollama(
        self,
        prompt: str,
        max_retries: int = 10,
        *,
        timeout: float = 1000,
        retry_delay: float = 60,
    ) -> str:
        """Send *prompt* to Ollama's ``/api/generate`` endpoint and return the text.

        Network errors, invalid JSON and 5xx/429 responses are retried up to
        *max_retries* times, sleeping *retry_delay* seconds between attempts.
        Other 4xx responses (e.g. an unknown model) abort immediately, because
        retrying an invalid request cannot succeed.

        Args:
            prompt: Full prompt text.
            max_retries: Maximum number of attempts.
            timeout: Per-request timeout in seconds, passed to ``requests``.
            retry_delay: Seconds to wait before the next attempt.

        Returns:
            The generated text, or ``""`` if every attempt failed.
        """
        attempts = 0
        for attempts in range(1, max_retries + 1):
            try:
                response = requests.post(
                    f"{self.ollama_url}/api/generate",
                    json={
                        "model": self.model_name,
                        "prompt": prompt,
                        "stream": False,
                        "options": {
                            "temperature": 0.7,
                            "top_p": 0.9,
                            # Ollama caps generation length with "num_predict";
                            # an OpenAI-style "max_tokens" key is not applied.
                            "num_predict": 4000,
                        },
                    },
                    timeout=timeout,
                )
                if response.status_code == 200:
                    return response.json().get("response", "")
                print(f"HTTP Error {response.status_code}: {response.text}")
                if 400 <= response.status_code < 500 and response.status_code != 429:
                    break  # Client error: the same request will fail again.
            except (requests.RequestException, ValueError) as e:
                # ValueError covers a 200 response whose body is not valid JSON.
                print(f"Attempt {attempts}/{max_retries} failed: {e}")

            if attempts < max_retries:
                print(f"⏳ Waiting {retry_delay} seconds for Ollama to recover...")
                time.sleep(retry_delay)

        print(f"❌ Failed after {attempts} attempts. Ollama may need manual restart.")
        return ""

    def parse_qa_response(self, response: str) -> List[Dict[str, str]]:
        """Parse ``Qn:`` / ``An:`` pairs out of a model response.

        Text before the first question label is discarded. A question without
        an answer label, or with an empty question or answer, is skipped.
        Markdown bold around the labels is tolerated and removed.

        Returns:
            A list of ``{"question": ..., "answer": ...}`` dicts in response order.
        """
        qa_pairs = []
        if not response:
            return qa_pairs

        for section in self._QUESTION_LABEL.split(response)[1:]:
            parts = self._ANSWER_LABEL.split(section, maxsplit=1)
            if len(parts) != 2:
                continue
            question = self._clean_fragment(parts[0])
            answer = self._clean_fragment(parts[1])
            if question and answer:
                qa_pairs.append({"question": question, "answer": answer})

        return qa_pairs

    def generate_qa_for_record(self, record: Dict) -> Optional[Dict]:
        """Generate 5 Q&A pairs for a single record.

        Returns:
            A result dict with ``level == "record"``, or ``None`` when the
            record has no content or generation/parsing failed.
        """
        disease, causal_organism, main_heading, sub_heading, content, _source = self.format_disease_data(record)

        record_identifier = f"{disease} - {main_heading}"
        if sub_heading:
            record_identifier += f" - {sub_heading}"

        print(f"\n🌾 Processing record: {record_identifier}")

        if not content.strip():
            print(f"⚠️  Skipping {record_identifier} - no content found")
            return None

        response = self.query_ollama(self.generate_record_prompt(record))
        if not response:
            print(f"❌ Failed to generate Q&A for {record_identifier}")
            return None

        qa_pairs = self.parse_qa_response(response)
        if not qa_pairs:
            print(f"❌ Failed to parse Q&A for {record_identifier}")
            return None

        print(f"✅ Generated {len(qa_pairs)} Q&A pairs for {record_identifier}")

        return {
            "disease": disease,
            "causal_organism": causal_organism,
            "main_heading": main_heading,
            "sub_heading": sub_heading,
            "qa_pairs": qa_pairs,
            "level": "record",
            "pages": self._field(record, "pages", []),
            "content_length": len(content),
        }

    def generate_qa_for_main_heading(self, heading_records: List[Dict]) -> Optional[Dict]:
        """Generate 10 Q&A pairs for one main heading, combining all its records.

        Returns:
            A result dict with ``level == "main_heading"``, or ``None`` when
            *heading_records* is empty or generation/parsing failed.
        """
        if not heading_records:
            return None

        first = heading_records[0]
        disease = self._field(first, "disease")
        main_heading = self._field(first, "main_heading")
        causal_organism = self._field(first, "causal_organism")

        print(f"\n🔬 Processing main heading: {disease} - {main_heading}")
        print(f"📋 Combining {len(heading_records)} records")

        response = self.query_ollama(self.generate_main_heading_prompt(heading_records))
        if not response:
            print(f"❌ Failed to generate Q&A for {disease} - {main_heading}")
            return None

        qa_pairs = self.parse_qa_response(response)
        if not qa_pairs:
            print(f"❌ Failed to parse Q&A for {disease} - {main_heading}")
            return None

        print(f"✅ Generated {len(qa_pairs)} Q&A pairs for {disease} - {main_heading}")

        subtopics = [record["sub_heading"] for record in heading_records if record.get("sub_heading")]

        return {
            "disease": disease,
            "causal_organism": causal_organism,
            "main_heading": main_heading,
            "sub_headings": subtopics,
            "qa_pairs": qa_pairs,
            "level": "main_heading",
            "record_count": len(heading_records),
            "total_content_length": sum(len(self._field(r, "content")) for r in heading_records),
        }

    def generate_qa_for_disease(self, disease_records: List[Dict]) -> Optional[Dict]:
        """Generate 20 Q&A pairs for a whole disease, combining all its records.

        Returns:
            A result dict with ``level == "disease"``, or ``None`` when
            *disease_records* is empty or generation/parsing failed.
        """
        if not disease_records:
            return None

        disease = self._field(disease_records[0], "disease")
        causal_organism = self._field(disease_records[0], "causal_organism")

        print(f"\n🦠 Processing complete disease: {disease}")
        print(f"📚 Combining {len(disease_records)} records from all aspects")

        response = self.query_ollama(self.generate_disease_prompt(disease_records))
        if not response:
            print(f"❌ Failed to generate Q&A for disease {disease}")
            return None

        qa_pairs = self.parse_qa_response(response)
        if not qa_pairs:
            print(f"❌ Failed to parse Q&A for disease {disease}")
            return None

        print(f"✅ Generated {len(qa_pairs)} Q&A pairs for complete disease {disease}")

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # saved JSON is identical across runs (set() order is not guaranteed).
        main_headings = list(dict.fromkeys(r["main_heading"] for r in disease_records if r.get("main_heading")))
        sub_headings = list(dict.fromkeys(r["sub_heading"] for r in disease_records if r.get("sub_heading")))

        return {
            "disease": disease,
            "causal_organism": causal_organism,
            "main_headings": main_headings,
            "sub_headings": sub_headings,
            "qa_pairs": qa_pairs,
            "level": "disease",
            "total_records": len(disease_records),
            "total_content_length": sum(len(self._field(r, "content")) for r in disease_records),
        }

    @staticmethod
    def _save_results(qa_data: List[Dict], output_file_path: str) -> None:
        """Write *qa_data* as pretty-printed UTF-8 JSON, replacing the file atomically."""
        # Write to a sibling temp file and swap it in, so an interruption
        # mid-write never leaves a truncated output file behind.
        tmp_path = f"{output_file_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(qa_data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_file_path)

    def _process_disease(self, disease: str, disease_records: List[Dict]) -> List[Dict]:
        """Run record, main-heading and disease level generation for one disease."""
        print(f"\n🎯 Processing rice disease: {disease}")
        print(f"   └── Contains {len(disease_records)} total records")

        results = []

        # 1. Record level (5 Q&A each), in input order.
        for i, record in enumerate(disease_records, 1):
            print(f"\n📍 Record progress: {i}/{len(disease_records)}")
            record_qa = self.generate_qa_for_record(record)
            if record_qa:
                results.append(record_qa)
            time.sleep(3)  # Rate limiting

        # 2. Main-heading level (10 Q&A each), headings in first-seen order.
        disease_headings = defaultdict(list)
        for record in disease_records:
            disease_headings[self._field(record, "main_heading", "Unknown")].append(record)
        for main_heading, heading_records in disease_headings.items():
            print(f"\n📋 Processing main heading: {disease} - {main_heading}")
            heading_qa = self.generate_qa_for_main_heading(heading_records)
            if heading_qa:
                results.append(heading_qa)
            time.sleep(3)  # Rate limiting

        # 3. Disease level (20 Q&A).
        disease_qa = self.generate_qa_for_disease(disease_records)
        if disease_qa:
            results.append(disease_qa)
        time.sleep(5)  # Longer pause between diseases

        return results

    def process_all_rice_diseases(self, json_file_path: str, output_file_path: str, *, checkpoint: bool = True):
        """Generate Q&A at all three levels for every record in *json_file_path*.

        Diseases are processed in first-seen order. For each one this runs
        record-level, then main-heading-level, then disease-level generation.
        Every successful result is written as a JSON list to
        *output_file_path*, followed by a printed summary. Load and save
        errors are printed rather than raised.

        Args:
            json_file_path: UTF-8 JSON file holding a list of record dicts.
            output_file_path: Destination for the generated Q&A list.
            checkpoint: When true (default), rewrite *output_file_path* after
                every disease so an interrupted run keeps the results produced
                so far.
        """
        print(f"🌾 Loading rice disease data from: {json_file_path}")

        try:
            with open(json_file_path, 'r', encoding='utf-8') as f:
                records = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers malformed JSON / bad UTF-8
            print(f"❌ Error loading JSON file: {e}")
            return
        if not isinstance(records, list):
            print(f"❌ Error loading JSON file: expected a list of records, got {type(records).__name__}")
            return

        print(f"📊 Found {len(records)} rice disease records to process")

        diseases_grouped = defaultdict(list)
        for record in records:
            diseases_grouped[self._field(record, "disease", "Unknown")].append(record)
        heading_count = sum(
            len({self._field(r, "main_heading", "Unknown") for r in disease_records})
            for disease_records in diseases_grouped.values()
        )

        print(f"🦠 Found {len(diseases_grouped)} rice diseases")
        print(f"📋 Found {heading_count} main heading combinations")

        all_qa_data = []
        for disease, disease_records in diseases_grouped.items():
            all_qa_data.extend(self._process_disease(disease, disease_records))
            if checkpoint:
                try:
                    self._save_results(all_qa_data, output_file_path)
                except (OSError, TypeError, ValueError) as e:
                    print(f"⚠️  Checkpoint save failed: {e}")

        try:
            self._save_results(all_qa_data, output_file_path)
        except (OSError, TypeError, ValueError) as e:
            print(f"❌ Error saving results: {e}")
            return

        successful_records = sum(1 for item in all_qa_data if item["level"] == "record")
        successful_headings = sum(1 for item in all_qa_data if item["level"] == "main_heading")
        successful_diseases = sum(1 for item in all_qa_data if item["level"] == "disease")

        print("\n🎉 SUCCESS!")
        print(f"✅ Processed {successful_records} individual records")
        print(f"✅ Processed {successful_headings} main headings")
        print(f"✅ Processed {successful_diseases} complete diseases")
        print(f"💾 Saved to: {output_file_path}")

        total_qa = sum(len(item["qa_pairs"]) for item in all_qa_data)
        print(f"📊 Total Q&A pairs generated: {total_qa}")

        print("\n📈 Expected totals:")
        print(f"   Records: {len(records)} × 5 = {len(records) * 5}")
        print(f"   Headings: {heading_count} × 10 = {heading_count * 10}")
        print(f"   Diseases: {len(diseases_grouped)} × 20 = {len(diseases_grouped) * 20}")
        print(f"   Total expected: {len(records) * 5 + heading_count * 10 + len(diseases_grouped) * 20}")

# # Usage example:
# if __name__ == "__main__":
#     # Initialize the generator
#     qa_generator = RiceDiseaseQAGenerator()
    
#     # Process all rice disease records
#     input_file = "rice_diseases.json"  # Your input file
#     output_file = "rice_disease_qa_results.json"  # Output file
    
#     qa_generator.process_all_rice_diseases(input_file, output_file)
    
#     # Or test single record:
#     sample_record = {
#         "disease": "Blast",
#         "causal_organism": "Pyricularia oryzae", 
#         "main_heading": "Symptoms",
#         "sub_heading": "Early leaf symptoms",
#         "content": "All aboveground parts of the rice plant (leaves, leaf collar, culm, culm nodes, neck, and panicle) are attacked by the fungus initial symptoms are white to gray-green lesions or spots with brown borders Small specks originate on leaves - subsequently enlarge into spindle shaped spots(0.5 to 1.5cm length, 0.3 to 0.5cm width) with ashy center.",
#         "pages": [1, 2],
#         "images": [],
#         "source": "Rice Diseases"
#     }
    
#     # Test single record (5 Q&A)
#     print("🧪 Testing single record...")
#     test_result = qa_generator.generate_qa_for_record(sample_record)
#     if test_result:
#         print(f"✅ Test successful! Generated {len(test_result['qa_pairs'])} Q&A pairs")
#         for i, qa in enumerate(test_result['qa_pairs'][:2], 1):
#             print(f"\nQ{i}: {qa['question']}")
#             print(f"A{i}: {qa['answer'][:100]}...")
#     else:
#         print("❌ Test failed")

In [30]:
# Generator bound to the local Ollama server and the gemma3n:e4b model
# (the constructor's defaults, spelled out here for visibility).
qa_generator = RiceDiseaseQAGenerator(ollama_url="http://localhost:11434", model_name="gemma3n:e4b")

In [31]:
# Dataset directory. Set RICE_DISEASES_DATA_DIR to run on another machine;
# the default is the original author's local checkout.
data_dir = os.environ.get(
    "RICE_DISEASES_DATA_DIR",
    "/Users/saikumarallaka/kaggle/gemma_3n_impact_challenge/datasets/rice_diseases",
)
input_file = os.path.join(data_dir, "rice_diseases_structured_data_extract.json")
output_file = os.path.join(data_dir, "rice_disease_qa_results.json")

In [None]:
# Full three-level run over the dataset; this is long-running (one Ollama call per record, heading and disease).
qa_generator.process_all_rice_diseases(json_file_path=input_file, output_file_path=output_file)

🌾 Loading rice disease data from: /Users/saikumarallaka/kaggle/gemma_3n_impact_challenge/datasets/rice_diseases/rice_diseases_structured_data_extract.json
📊 Found 69 rice disease records to process
🦠 Found 11 rice diseases
📋 Found 33 main heading combinations

🎯 Processing rice disease: Blast
   └── Contains 10 total records

📍 Record progress: 1/10

🌾 Processing record: Blast - Symptoms
✅ Generated 5 Q&A pairs for Blast - Symptoms

📍 Record progress: 2/10

🌾 Processing record: Blast - Symptoms - Leaf Blast
✅ Generated 5 Q&A pairs for Blast - Symptoms - Leaf Blast

📍 Record progress: 3/10

🌾 Processing record: Blast - Symptoms - Neck Blast
✅ Generated 5 Q&A pairs for Blast - Symptoms - Neck Blast

📍 Record progress: 4/10

🌾 Processing record: Blast - Symptoms - Nodal Blast
✅ Generated 5 Q&A pairs for Blast - Symptoms - Nodal Blast

📍 Record progress: 5/10

🌾 Processing record: Blast - Identification of pathogen
✅ Generated 5 Q&A pairs for Blast - Identification of pathogen

📍 Record pr

In [None]:
# # Usage example:
# if __name__ == "__main__":
#     # Initialize the generator
#     qa_generator = RiceDiseaseQAGenerator()
    
#     # Process all rice disease records
#     input_file = "rice_diseases.json"  # Your input file
#     output_file = "rice_disease_qa_results.json"  # Output file
    
#     qa_generator.process_all_rice_diseases(input_file, output_file)
    
    # # Or test single record:
    # sample_record = {
    #     "disease": "Blast",
    #     "causal_organism": "Pyricularia oryzae", 
    #     "main_heading": "Symptoms",
    #     "sub_heading": "",
    #     "content": "All aboveground parts of the rice plant (leaves, leaf collar, culm, culm nodes, neck, and panicle) are attacked by the fungus initial symptoms are white to gray-green lesions or spots with brown borders Small specks originate on leaves - subsequently enlarge into spindle shaped spots(0.5 to 1.5cm length, 0.3 to 0.5cm width) with ashy center.",
    #     "pages": [1, 2],
    #     "images": [],
    #     "source": "Rice Diseases"
    # }
    
    # # Test single record (5 Q&A)
    # print("🧪 Testing single record...")
    # test_result = qa_generator.generate_qa_for_record(sample_record)
    # if test_result:
    #     print(f"✅ Test successful! Generated {len(test_result['qa_pairs'])} Q&A pairs")
    #     for i, qa in enumerate(test_result['qa_pairs'][:2], 1):
    #         print(f"\nQ{i}: {qa['question']}")
    #         print(f"A{i}: {qa['answer'][:100]}...")
    # else:
    #     print("❌ Test failed")