# In [None]:
import json
import os
import requests
import time
from pathlib import Path
from typing import Dict, List, Any
from dataclasses import dataclass

@dataclass
class ParaphraseConfig:
    """Settings for RadiologyQuestionParaphraser.

    Attributes:
        num_variations: Number of paraphrases to produce per question (>= 1).
        delay_between_requests: Seconds to sleep after each LLM call (>= 0).
            Floats are accepted, so sub-second throttling is possible.
        llm_url: Generation endpoint; it is POSTed a ``{"model", "prompt"}``
            payload and expected to stream newline-delimited JSON.
        model_name: Model identifier sent in the request payload.
        output_dir: Directory that receives one JSON result file per question.
    """

    num_variations: int = 10
    delay_between_requests: float = 1
    llm_url: str = "http://10.127.30.113:11434/api/generate"
    model_name: str = "llama3.1"
    output_dir: str = "/share/ssddata/sarimhashmi/corrected_iuxray/corrected_iuxray"

    def __post_init__(self) -> None:
        # Fail fast on settings that would otherwise surface deep into a long
        # batch run (a meaningless "exactly 0 variations" prompt, or
        # time.sleep raising ValueError after the first LLM call).
        if self.num_variations < 1:
            raise ValueError(
                f"num_variations must be >= 1, got {self.num_variations}"
            )
        if self.delay_between_requests < 0:
            raise ValueError(
                "delay_between_requests must be >= 0, "
                f"got {self.delay_between_requests}"
            )

class RadiologyQuestionParaphraser:
    """Generate paraphrased variants of yes/no radiology questions with an LLM.

    For every question in a JSON-lines input file, the configured LLM endpoint
    is prompted for ``config.num_variations`` rephrasings. Each output line is
    cleaned and validated, and the results are written to one JSON file per
    question in ``config.output_dir``. Missing variations are padded with the
    original question, so every output file has exactly ``num_variations``
    variation entries.
    """

    # Connect/read timeout in seconds for the LLM request. For a streamed
    # response this bounds the wait between received bytes, not the total
    # time. Without it a stalled server would block the whole batch forever.
    REQUEST_TIMEOUT = 300

    # Number of LLM calls per question before padding with the original.
    MAX_ATTEMPTS = 3

    # Instruction sentence that precedes the "[yes, no]" options.
    _OPTIONS_PHRASE = "Please choose from the following two options:"
    # Suffix appended to variations that lack the options.
    _OPTIONS_SUFFIX = " Please choose from the following two options: [yes, no]"
    # Label prefixes the model sometimes puts in front of a variation.
    _PREFIXES_TO_REMOVE = (
        "Original:", "Variation:", "Alternative:", "Query:",
        "Technical term:", "Focused query:", "Clinical context:",
    )

    def __init__(self, config: "ParaphraseConfig"):
        """Store *config* and make sure its output directory exists."""
        self.config = config
        os.makedirs(self.config.output_dir, exist_ok=True)

    def query_llm(self, prompt: str) -> str:
        """Send *prompt* to the LLM and return the concatenated streamed reply.

        The server is expected to stream newline-delimited JSON objects. The
        ``response`` field of each object is appended until one reports
        ``done``. Transport, HTTP-status and decoding errors are printed and an
        empty string is returned, so the caller can simply retry.
        """
        payload = {"model": self.config.model_name, "prompt": prompt}
        chunks = []
        try:
            # The ``with`` block releases the pooled connection even when we
            # stop reading early (on ``done``) or an exception interrupts it.
            with requests.post(
                self.config.llm_url,
                json=payload,
                stream=True,
                timeout=self.REQUEST_TIMEOUT,
            ) as response:
                # Treat 4xx/5xx as failures instead of parsing an error body.
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line:
                        continue
                    message = json.loads(line)
                    if not isinstance(message, dict):
                        raise ValueError(f"unexpected stream item: {message!r}")
                    # NOTE(review): assumes the server reports mid-stream
                    # failures as {"error": ...} (Ollama does) — confirm for
                    # other backends.
                    if "error" in message:
                        raise ValueError(message["error"])
                    chunks.append(message.get("response", ""))
                    if message.get("done", False):
                        break
        except (requests.RequestException, ValueError) as e:
            print(f"Error querying LLM: {e}")
            return ""
        return "".join(chunks)

    def is_valid_variation(self, text: str) -> bool:
        """Return True if *text* is a usable, fully formatted question.

        A valid variation is at least 20 characters long, contains exactly one
        ``<image>`` tag, includes the ``[yes, no]`` options, and has some
        question text besides the standard options sentence.
        """
        if len(text) < 20:  # too short to be a real question
            return False
        if text.count("<image>") != 1:  # exactly one image placeholder
            return False
        if "[yes, no]" not in text:  # the answer options are mandatory
            return False
        # Reject lines that are only the options sentence (e.g. the model put
        # the question and the options on separate lines).
        stem = text.partition("[yes, no]")[0].replace(self._OPTIONS_PHRASE, "")
        return bool(stem.strip())

    def clean_variation(self, text: str, original_question: str) -> str:
        """Normalise one raw LLM output line into the canonical question format.

        The result is the question text, ending with the "[yes, no]" options
        sentence, followed by a newline and ``<image>``. To get there, this
        method:

        * removes markdown ``*`` markers, a leading list number, known label
          prefixes and wrapping double quotes;
        * moves a trailing inline ``<image>`` onto its own line;
        * appends the options sentence when it is missing.

        Args:
            text: One line (or snippet) of LLM output.
            original_question: Returned unchanged when the cleaned text does
                not pass ``is_valid_variation``.

        Returns:
            The cleaned variation, or *original_question* if it is unusable.
        """
        # Strip emphasis first, so "**3.** ..." still has its number removed.
        text = text.replace("*", "").strip()

        # Drop a leading list number like "3." (any number, not only
        # 1..num_variations, because the model does not always stop at N).
        number, dot, rest = text.partition(".")
        if dot and number.isdecimal():
            text = rest.strip()

        for prefix in self._PREFIXES_TO_REMOVE:
            if text.startswith(prefix):
                text = text[len(prefix):].strip()

        # The prompt shows the original question in double quotes, and the
        # model sometimes echoes that quoting around a variation.
        if len(text) >= 2 and text[0] == text[-1] == '"':
            text = text[1:-1].strip()

        # Detach a trailing "<image>" (inline or on its own line), so that
        # re-appending it below never yields two tags.
        if text.endswith("<image>"):
            text = text[: -len("<image>")]
        text = text.rstrip()

        if "[yes, no]" not in text:
            text += self._OPTIONS_SUFFIX
        text += "\n<image>"

        if not self.is_valid_variation(text):
            return original_question
        return text

    def generate_prompt(self, question: str) -> str:
        """Build the instruction prompt asking for ``num_variations`` rephrasings.

        Note: the ``\\n`` in item 6 is a real newline in the prompt. As a
        result, the model usually emits ``<image>`` on a line of its own.
        """
        return f'''You are an expert radiologist. Generate exactly {self.config.num_variations} different ways to ask this radiology question. Each variation must:
1. Keep the exact same medical meaning but use different phrasing and terminology
2. Be a complete, standalone question (not just "<image>")
3. Use proper medical terminology appropriate for a radiology report
4. Each variation must be a proper question about the specific condition mentioned
5. Include "Please choose from the following two options: [yes, no]" at the end
6. End with "\n<image>"

Original question: "{question}"

Important: Each variation must be a complete question, not just "<image>". Make each variation unique and meaningful.

Provide only the variations without any additional text or numbering.'''

    def _collect_variations(self, original_question: str) -> List[str]:
        """Query the LLM and return exactly ``num_variations`` question strings.

        At most ``MAX_ATTEMPTS`` calls are made, and each call is followed by
        ``delay_between_requests`` seconds of sleep. Distinct valid
        paraphrases come first; the rest are padded with *original_question*.
        """
        target = self.config.num_variations
        prompt = self.generate_prompt(original_question)  # same for every attempt
        collected: List[str] = []
        seen = set()

        for _ in range(self.MAX_ATTEMPTS):
            if len(collected) >= target:
                break
            response = self.query_llm(prompt)
            for raw in response.split("\n"):
                raw = raw.strip()
                # Skip blank lines and preambles such as "Here are 10 ways:".
                if not raw or raw.endswith(":"):
                    continue
                cleaned = self.clean_variation(raw, original_question)
                # clean_variation falls back to the original for unusable
                # lines. That is not a paraphrase, so it must not count
                # toward the quota (it would end the retries too early).
                if cleaned == original_question or cleaned in seen:
                    continue
                if self.is_valid_variation(cleaned):
                    seen.add(cleaned)
                    collected.append(cleaned)
            time.sleep(self.config.delay_between_requests)

        collected.extend([original_question] * (target - len(collected)))
        return collected[:target]

    def process_questions(self, input_file: str) -> None:
        """Paraphrase every question in a JSON-lines file and save the results.

        Each non-blank line of *input_file* must be a JSON object with
        ``question``, ``answer`` and ``image`` keys. For the record on 1-based
        line ``idx``, the file ``question_{idx}_variants_results.json`` is
        written to ``config.output_dir``. It contains those three fields plus
        ``variation_1`` .. ``variation_N``.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record lacks ``question``, ``answer`` or ``image``.
        """
        # Explicit UTF-8 on both ends: the output is dumped with
        # ensure_ascii=False, so a non-UTF-8 platform default encoding could
        # fail on non-ASCII text.
        with open(input_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f, 1):
                # Blank lines still advance idx, so output file names keep
                # matching input line numbers.
                if not line.strip():
                    continue

                item = json.loads(line)
                question = item["question"]
                # Read the other required keys before the (slow) LLM calls,
                # so a malformed record fails immediately.
                answer = item["answer"]
                image = item["image"]

                output_item = {"question": question, "answer": answer, "image": image}
                for i, variation in enumerate(self._collect_variations(question), 1):
                    output_item[f"variation_{i}"] = variation

                output_file = os.path.join(
                    self.config.output_dir, f"question_{idx}_variants_results.json"
                )
                with open(output_file, "w", encoding="utf-8") as out:
                    json.dump(output_item, out, ensure_ascii=False, indent=2)

                print(f"Processed question {idx} -> {output_file}")

def main(
    input_file: str = "/share/ssddata/sarimhashmi/corrected_iuxray/vanillah_iuxray_json.json",
) -> None:
    """Paraphrase every question in *input_file* using the default configuration.

    Args:
        input_file: JSON-lines file with one object per line; each object is
            read for its ``question``, ``answer`` and ``image`` keys. It
            defaults to the previously hard-coded dataset path, so ``main()``
            behaves exactly as before. Pass a path to reuse the script on
            another dataset without editing it.
    """
    config = ParaphraseConfig()
    paraphraser = RadiologyQuestionParaphraser(config)
    paraphraser.process_questions(input_file)

# Run the batch paraphrasing job only when executed as a script, not on import.
if __name__ == "__main__":
    main()