In [None]:
import torch
from transformers import BertTokenizer, BertForQuestionAnswering
import tensorflow as tf
import numpy as np
from torch.utils.data import DataLoader, Dataset
from fpdf import FPDF
from datetime import datetime
import os
import logging
from typing import Optional, Dict, List, Tuple
import json
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
import shutil
from fpdf.enums import XPos, YPos 

class LungCancerQASystem:
    """Interactive lung-image classifier and medical question-answering helper.

    The system combines three pieces that are loaded at construction time:

    * a Keras CNN used by :meth:`classify_image`,
    * a BERT question-answering model used by :meth:`answer_question`,
    * a JSON knowledge base (``contexts``, ``conditions`` and optional
      ``questions``/``answers`` lists).

    Every classified image gets its own "case directory" below the output
    path. It holds a copy of the image, ``classification_output.json``,
    ``clinical_report.pdf`` and one ``qa_output_*.json`` per answered question.
    """

    # Answer returned when there is no context text at all to search.
    _NO_CONTEXT_ANSWER = "No relevant information found."

    def __init__(self, model_path: str = 'custom_resnet_lung_model.keras',
                 knowledge_path: str = 'medical_knowledge.json',
                 output_path: str = 'qa_outputs',
                 log_level: int = logging.INFO):
        """
        Initialize the Lung Cancer QA System with improved context handling

        Args:
            model_path: Path of the Keras model file loaded for classification.
            knowledge_path: Path of the JSON knowledge base.
            output_path: Base directory under which case directories are created.
            log_level: Level passed to ``logging.basicConfig``.

        Raises:
            RuntimeError: If the knowledge base, the CNN or BERT cannot be loaded.
        """
        self.setup_logging(log_level)
        self.base_output_path = output_path
        os.makedirs(output_path, exist_ok=True)

        # Initialize components
        self.initialize_medical_knowledge(knowledge_path)
        self.setup_models(model_path)
        self.setup_bert()

        # Initialize statistics
        self.stats = {
            'classifications': 0,
            'questions_answered': 0,
            'reports_generated': 0,
            'image_classifications': 0
        }

        # Track current case directory
        self.current_case_dir = None

    def setup_bert(self):
        """Initialize BERT model with error handling

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded; the
                original exception is attached as ``__cause__``.
        """
        try:
            # NOTE(review): 'bert-base-uncased' is a base checkpoint; its QA
            # head is presumably not fine-tuned, so span predictions may be
            # unreliable - confirm whether a SQuAD-tuned checkpoint is intended.
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.qa_model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
            self.logger.info("BERT QA model initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing BERT QA model: {e}")
            raise RuntimeError("Failed to initialize BERT model") from e

    def setup_logging(self, log_level: int) -> None:
        """Set up logging configuration

        Configures the root logger to write to ``lung_cancer_qa.log`` and to
        the console, then stores a module logger on ``self.logger``.
        """
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('lung_cancer_qa.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def initialize_medical_knowledge(self, knowledge_path: str) -> None:
        """Initialize medical knowledge with improved error handling

        Sets ``self.medical_data`` (the whole JSON document), ``self.contexts``
        (all context strings joined by spaces) and ``self.conditions``.

        Raises:
            RuntimeError: If the file cannot be read or parsed.
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.medical_data = data
            self.contexts = " ".join(data.get('contexts', []))
            self.conditions = data.get('conditions', {})
            self.logger.info("Medical knowledge loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading medical knowledge: {e}")
            raise RuntimeError("Failed to load medical knowledge") from e

    def setup_models(self, model_path: str) -> None:
        """Set up CNN model with improved error handling

        Loads the Keras model and ``class_indices.json`` (class name -> index)
        from the working directory, and builds the inverse ``self.class_map``
        (index -> class name).

        Raises:
            RuntimeError: If the model or the class index file cannot be loaded.
        """
        try:
            self.cnn_model = tf.keras.models.load_model(model_path)
            with open('class_indices.json', 'r', encoding='utf-8') as f:
                self.class_indices = json.load(f)
            self.class_map = {v: k for k, v in self.class_indices.items()}
            self.logger.info("CNN model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading CNN model: {e}")
            raise RuntimeError("Failed to load CNN model") from e

    def create_case_directory(self, image_path: str) -> str:
        """Create a new directory for the current case

        The directory is named ``<timestamp>_<image stem>`` under the base
        output path, receives a copy of the image, and becomes
        ``self.current_case_dir``.

        Returns:
            The path of the case directory.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        case_dir = os.path.join(self.base_output_path, f"{timestamp}_{image_name}")
        os.makedirs(case_dir, exist_ok=True)

        # Copy the input image to the case directory
        image_copy_path = os.path.join(case_dir, os.path.basename(image_path))
        shutil.copy2(image_path, image_copy_path)

        self.current_case_dir = case_dir
        return case_dir

    def save_qa_output(self, data: Dict) -> str:
        """Save QA output to JSON in the current case directory

        The file is named ``qa_output_<timestamp>.json``. Because the
        timestamp only has one-second resolution, a numeric suffix is added
        when that name is already taken so earlier answers are not overwritten.

        Returns:
            The path of the file that was written.

        Raises:
            RuntimeError: If no case directory is active.
        """
        if not self.current_case_dir:
            raise RuntimeError("No active case directory")

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        stem = f"qa_output_{timestamp}"
        filepath = os.path.join(self.current_case_dir, f"{stem}.json")
        suffix = 1
        while os.path.exists(filepath):
            filepath = os.path.join(self.current_case_dir, f"{stem}_{suffix}.json")
            suffix += 1

        try:
            with open(filepath, 'w') as f:
                json.dump(data, f, indent=4)
            self.logger.info(f"QA output saved to {filepath}")
            return filepath
        except Exception as e:
            self.logger.error(f"Error saving QA output: {e}")
            raise

    def save_classification_output(self, data: Dict) -> str:
        """Save classification output to JSON in the current case directory

        Returns:
            The path of ``classification_output.json``.

        Raises:
            RuntimeError: If no case directory is active.
        """
        if not self.current_case_dir:
            raise RuntimeError("No active case directory")

        filename = "classification_output.json"
        filepath = os.path.join(self.current_case_dir, filename)

        try:
            with open(filepath, 'w') as f:
                json.dump(data, f, indent=4)
            self.logger.info(f"Classification output saved to {filepath}")
            return filepath
        except Exception as e:
            self.logger.error(f"Error saving classification output: {e}")
            raise

    @staticmethod
    def _pdf_safe(text) -> str:
        """Return *text* with characters outside Latin-1 replaced by ``?``.

        The report uses the built-in Helvetica font, which cannot encode
        characters outside Latin-1.
        """
        return str(text).encode('latin-1', 'replace').decode('latin-1')

    def generate_report(self, data: Dict) -> str:
        """Write ``clinical_report.pdf`` for the current case.

        Args:
            data: Classification result; must contain ``timestamp``,
                ``classification`` and ``confidence``. An optional ``details``
                mapping may provide ``info`` and ``recommendations``.

        Returns:
            The path of the generated PDF.

        Raises:
            RuntimeError: If no case directory is active.
        """
        if not self.current_case_dir:
            raise RuntimeError("No active case directory")

        try:
            # Same cursor movement as the deprecated ``ln=True``. It is also
            # passed to multi_cell, whose default leaves the cursor at the
            # right edge so the next cell would start outside the text area.
            next_line = {"new_x": XPos.LMARGIN, "new_y": YPos.NEXT}
            safe = self._pdf_safe

            pdf = FPDF()
            # Use A4 format (default)
            pdf.add_page()

            # Title
            pdf.set_font("Helvetica", 'B', 16)
            pdf.cell(190, 10, 'Lung Cancer Analysis Report', **next_line)

            pdf.set_font("Helvetica", size=12)
            pdf.cell(190, 10, safe(f"Generated: {data['timestamp']}"), **next_line)

            # Add empty line
            pdf.ln(10)

            # Classification Results
            pdf.set_font("Helvetica", 'B', 14)
            pdf.cell(190, 10, 'Classification Results:', **next_line)
            pdf.set_font("Helvetica", size=12)
            pdf.cell(190, 10, safe(f"Classification: {data['classification']}"), **next_line)
            pdf.cell(190, 10, f"Confidence: {data['confidence']:.2%}", **next_line)

            # Add details if available
            if "details" in data:
                pdf.ln(10)
                pdf.set_font("Helvetica", 'B', 14)
                pdf.cell(190, 10, 'Details:', **next_line)
                pdf.set_font("Helvetica", size=12)

                details = data["details"]
                if "info" in details:
                    pdf.multi_cell(190, 10, safe(f"Medical Information: {details['info']}"),
                                   **next_line)

                if "recommendations" in details:
                    pdf.ln(5)
                    pdf.cell(190, 10, "Recommendations:", **next_line)
                    for rec in details["recommendations"]:
                        pdf.multi_cell(190, 10, safe(f"- {rec}"), **next_line)

            # Save the PDF
            filepath = os.path.join(self.current_case_dir, "clinical_report.pdf")
            pdf.output(filepath)

            self.stats['reports_generated'] += 1
            self.logger.info(f"Report generated: {filepath}")
            return filepath

        except Exception as e:
            self.logger.error(f"Error generating report: {e}")
            raise

    def _record_answer(self, question: str, answer: str, confidence: float) -> Dict:
        """Build the response dict, persist it for the active case, count it."""
        response = {
            "question": question,
            "answer": answer,
            "confidence": confidence,
            "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        if self.current_case_dir:
            self.save_qa_output(response)

        self.stats['questions_answered'] += 1
        return response

    def _load_case_context(self) -> Optional[str]:
        """Return the current case's ``details`` as JSON text, or ``None``.

        This is best effort: a missing or unreadable classification file only
        means the question is answered without case-specific context.
        """
        if not self.current_case_dir:
            return None

        case_file = os.path.join(self.current_case_dir, "classification_output.json")
        try:
            with open(case_file) as f:
                case_data = json.load(f)
        except (OSError, ValueError) as e:
            self.logger.debug(f"Case context unavailable ({case_file}): {e}")
            return None

        if isinstance(case_data, dict) and "details" in case_data:
            return json.dumps(case_data["details"])
        return None

    def _best_bert_answer(self, question: str,
                          contexts: List[str]) -> Tuple[Optional[str], float]:
        """Run BERT over every context and return the best (answer, score).

        The score is the sum of the chosen start and end logits. The answer is
        ``None`` when no context produced a non-empty span.
        """
        best_answer = None
        best_confidence = float('-inf')

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for context in contexts:
                inputs = self.tokenizer.encode_plus(
                    question,
                    context,
                    return_tensors='pt',
                    max_length=512,
                    truncation=True,
                    padding='max_length'
                )

                outputs = self.qa_model(**inputs)
                start_scores = outputs.start_logits[0]
                end_scores = outputs.end_logits[0]

                start_idx = int(torch.argmax(start_scores))
                # Only consider end positions at or after the chosen start.
                end_idx = int(torch.argmax(end_scores[start_idx:])) + start_idx

                confidence = float(start_scores[start_idx] + end_scores[end_idx])

                if confidence > best_confidence:
                    answer_tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
                    answer_text = self.tokenizer.decode(
                        answer_tokens, skip_special_tokens=True).strip()

                    # A span that decodes to nothing never becomes the best answer.
                    if answer_text:
                        best_answer = answer_text
                        best_confidence = confidence

        return best_answer, best_confidence

    def answer_question(self, question: str) -> Dict:
        """Enhanced question answering with current case context

        Resolution order:

        1. exact (case-insensitive, trimmed) match against the predefined
           question/answer pairs -> confidence ``1.0``;
        2. BERT span extraction over every knowledge context plus the current
           case's details, if any;
        3. if BERT yields no answer or a negative score, the first sentence of
           the context sharing the most words with the question ->
           confidence ``0.5``;
        4. if there is no context at all, a fixed fallback answer ->
           confidence ``0.0``.

        Returns:
            Dict with ``question`` (normalised), ``answer``, ``confidence``
            and ``timestamp``. It is also saved to the active case directory.
        """
        try:
            question = question.lower().strip()

            # First check exact matches in predefined Q&A pairs
            qa_pairs = zip(self.medical_data.get('questions', []),
                           self.medical_data.get('answers', []))
            for known_question, known_answer in qa_pairs:
                if question == known_question.lower().strip():
                    return self._record_answer(question, known_answer, 1.0)

            # Include current case context if available
            contexts = list(self.medical_data.get('contexts', []))
            case_context = self._load_case_context()
            if case_context is not None:
                contexts.append(case_context)

            if not contexts:
                # Nothing to search; avoids max() on an empty sequence below.
                return self._record_answer(question, self._NO_CONTEXT_ANSWER, 0.0)

            # If no exact match, use BERT for contextual search
            best_answer, best_confidence = self._best_bert_answer(question, contexts)

            if not best_answer or best_confidence < 0:
                question_words = set(question.split())
                relevant_context = max(
                    contexts,
                    key=lambda ctx: len(question_words & set(ctx.lower().split())))
                best_answer = relevant_context.split('.')[0] + '.'
                best_confidence = 0.5

            return self._record_answer(question, best_answer, best_confidence)

        except Exception as e:
            self.logger.error(f"Error in question answering: {e}")
            raise

    def classify_image(self, image_path: str) -> Dict:
        """Enhanced image classification with case directory organization

        Creates a new case directory, classifies the image with the CNN, then
        writes ``classification_output.json`` and the PDF report.

        Returns:
            Dict with ``classification``, ``confidence``, ``timestamp``,
            ``image_path`` and ``predictions`` (class name -> probability).
            ``details`` is added when the knowledge base has an entry for the
            predicted class.
        """
        try:
            # Create new case directory
            self.create_case_directory(image_path)

            img = keras_image.load_img(image_path, target_size=(128, 128))
            img_array = keras_image.img_to_array(img) / 255.0
            img_array = np.expand_dims(img_array, axis=0)

            predictions = self.cnn_model.predict(img_array)
            probabilities = predictions[0]
            class_idx = int(np.argmax(predictions, axis=1)[0])
            classification = self.class_map[class_idx]

            response = {
                "classification": classification,
                "confidence": float(probabilities[class_idx]),
                "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                "image_path": image_path,
                # Pair each probability with its class by index, not by the
                # key order of class_indices.json.
                "predictions": {
                    self.class_map.get(index, str(index)): float(prob)
                    for index, prob in enumerate(probabilities)
                }
            }

            if classification in self.conditions:
                response["details"] = self.conditions[classification]

            self.save_classification_output(response)
            self.generate_report(response)
            self.stats['image_classifications'] += 1

            return response

        except Exception as e:
            self.logger.error(f"Error in image classification: {e}")
            raise

    def run(self):
        """Run the interactive console menu until the user exits.

        Errors raised by an action are logged and printed, and the menu is
        shown again. The loop also ends when the input stream is closed.
        """
        print("\nWelcome to the Enhanced Lung Cancer QA System")
        print("============================================")

        while True:
            try:
                print("\nOptions:")
                print("1. Analyze lung image")
                print("2. Ask medical question")
                print("3. View system statistics")
                print("4. Exit")

                choice = input("\nSelect option (1-4): ").strip()

                if choice == '1':
                    image_path = input("\nEnter image path: ").strip()
                    if not os.path.exists(image_path):
                        print("Error: Image file not found!")
                        continue

                    result = self.classify_image(image_path)
                    print(f"\nClassification: {result['classification']}")
                    print(f"Confidence: {result['confidence']:.2%}")
                    print(f"\nCase directory created: {self.current_case_dir}")

                    if "details" in result:
                        print("\nDetails:")
                        for key, value in result["details"].items():
                            if isinstance(value, str):
                                print(f"{key.capitalize()}: {value}")
                            elif isinstance(value, list):
                                print(f"\n{key.capitalize()}:")
                                for item in value:
                                    print(f"- {item}")

                elif choice == '2':
                    question = input("\nEnter your medical question: ").strip()
                    result = self.answer_question(question)
                    print(f"\nAnswer: {result['answer']}")
                    print(f"Confidence: {result['confidence']:.2f}")

                elif choice == '3':
                    print("\nSystem Statistics:")
                    for key, value in self.stats.items():
                        print(f"{key.replace('_', ' ').title()}: {value}")

                elif choice == '4':
                    print("\nThank you for using the Enhanced Lung Cancer QA System!")
                    break

                else:
                    print("\nInvalid choice! Please select 1-4.")

            except EOFError:
                # stdin is closed: input() would fail again on every pass, so
                # retrying would spin forever.
                print("\nInput stream closed. Exiting.")
                break
            except Exception as e:
                self.logger.error(f"Error in main loop: {e}")
                print(f"\nAn error occurred: {str(e)}")
                print("Please try again.")

if __name__ == "__main__":
    # Script entry point: build the system (loads the models and the knowledge
    # base) and start the interactive menu. Any exception that reaches this
    # level is reported on the console instead of producing a traceback.
    try:
        system = LungCancerQASystem()
        system.run()
    except Exception as startup_error:
        print(f"System initialization failed: {startup_error!s}")
        print("Please check logs for details.")

In [None]:
from pathlib import Path
from typing import Dict
from fpdf import FPDF
from PIL import Image
import logging
import json

# Configure logging
# Root logger at INFO level; `logger` is the module-level logger used by
# ReportGenerator below.
# NOTE(review): logging.basicConfig() does nothing when the root logger
# already has handlers, so if LungCancerQASystem.setup_logging() ran earlier
# in the same session this call presumably has no effect - confirm.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ProcessingResult:
    """Outcome of a processing step.

    Attributes are set directly from the constructor arguments:
    ``success`` (did the step work), ``message`` (human-readable summary)
    and ``data`` (extra payload; a fresh empty dict when omitted).
    """

    def __init__(self, success: bool, message: str, data: dict = None):
        # A new dict per instance when no payload is given, so results never
        # share one mutable default.
        if data is None:
            data = {}
        self.success, self.message, self.data = success, message, data

class ReportGenerator:
    """Build a PDF clinical report for the most recent case folder.

    ``user_info`` must provide the ``name`` and ``id`` keys; both are printed
    in the report header.
    """

    def __init__(self, user_info):
        self.user_info = user_info

    @staticmethod
    def _latin1_safe(text) -> str:
        """Return *text* with characters outside Latin-1 replaced by ``?``.

        The report uses a built-in (core) PDF font, which cannot encode
        characters outside Latin-1 such as typographic dashes and quotes.
        """
        return str(text).encode("latin-1", "replace").decode("latin-1")

    def extract_classification_descriptions(self, content: str) -> Dict[str, str]:
        """Extract classification descriptions from markdown content.

        Every ``## `` heading starts a new classification. The non-blank lines
        that follow are collected, each terminated by a newline, until the next
        heading. Lines containing ``Predicted Class:`` are ignored, as is any
        text that precedes the first heading.
        """
        classifications = {}
        current_classification = None

        for line in content.splitlines():
            if line.startswith("## "):  # New classification header
                current_classification = line[3:].strip()  # Capture the header text
                classifications[current_classification] = ""
            elif "Predicted Class:" in line:  # Look for the predicted class line
                continue  # Skip this line
            elif current_classification and line.strip():  # Append description to current classification
                classifications[current_classification] += line + "\n"

        return classifications

    def find_latest_subfolder(self, base_folder: Path) -> Path:
        """Find the most recent subfolder in the qa_outputs directory.

        "Most recent" means the greatest folder name; this relies on the names
        starting with a sortable timestamp.

        Raises:
            ValueError: If ``base_folder`` contains no subfolder.
        """
        subfolders = [f for f in base_folder.iterdir() if f.is_dir()]
        if not subfolders:
            raise ValueError("No subfolders found in qa_outputs directory")

        return max(subfolders, key=lambda folder: folder.name)

    def find_image_in_folder(self, folder: Path) -> Path:
        """Find the first image file in the specified folder.

        Extensions are tried in the order ``.jpg``, ``.jpeg``, ``.png`` and
        compared case-insensitively, so ``scan.PNG`` is found as well. Within
        one extension the alphabetically first file wins.

        Returns:
            The image path, or ``None`` when the folder does not exist or holds
            no matching file.
        """
        image_extensions = ('.jpg', '.jpeg', '.png')
        if not folder.is_dir():
            return None

        candidates = sorted(entry for entry in folder.iterdir() if entry.is_file())
        for ext in image_extensions:
            for candidate in candidates:
                if candidate.suffix.lower() == ext:
                    return candidate
        return None

    def create_pdf_report(self, base_folder: Path, classification: str, markdown_content: Dict[str, str]) -> "ProcessingResult":
        """Create the final PDF report with classification and related description.

        The report is written as ``clinical_report.pdf`` into the latest
        subfolder of ``base_folder``. It embeds the folder's image when one is
        found and can be opened.

        Args:
            base_folder: Directory that contains the timestamped case folders.
            classification: Predicted class; used as the key into
                ``markdown_content``.
            markdown_content: Mapping of classification name to description.

        Returns:
            A ``ProcessingResult``. On success its ``data["output_path"]``
            holds the PDF path; on any error ``success`` is ``False`` and
            ``message`` describes the failure (no exception is raised).
        """
        try:
            safe = self._latin1_safe

            # Find the latest subfolder
            latest_subfolder = self.find_latest_subfolder(base_folder)
            logger.info(f"Using folder: {latest_subfolder}")

            # Find image in the subfolder
            image_path = self.find_image_in_folder(latest_subfolder)
            if not image_path:
                logger.warning("No image file found in the folder")

            pdf = FPDF()
            pdf.add_page()

            # Set margins
            pdf.set_margins(10, 10, 10)

            # Add Image if found
            if image_path:
                try:
                    # Close the file handle once the integrity check is done.
                    with Image.open(image_path) as img:
                        img.verify()  # Verify image is not corrupted
                    pdf.image(str(image_path), x=10, y=10, w=60)  # Add the image to the PDF
                    logger.info(f"Successfully added image: {image_path}")
                except Exception as img_error:
                    logger.warning(f"Image loading error: {str(img_error)} - Skipping image.")

            # Add Title and User Info.
            # Width 0 means "up to the right margin", which keeps centred text
            # centred on the page.
            pdf.set_xy(10, 80)
            pdf.set_font("Arial", "B", 16)
            pdf.cell(0, 10, "Clinical Report", ln=True, align="C")

            pdf.set_font("Arial", "", 10)
            pdf.cell(0, 10, safe(f"Generated for: {self.user_info['name']} (ID: {self.user_info['id']})"), ln=True, align="C")
            # The disclaimer is wider than one line, so it must wrap.
            pdf.multi_cell(0, 5, "Disclaimer: This report is generated by an AI system and should not be used as a replacement for professional medical advice.", align="C")
            # Depending on the FPDF version, multi_cell may leave the cursor at
            # the right edge; return it to the left margin explicitly.
            pdf.set_x(pdf.l_margin)

            # Add classification and its description
            if classification in markdown_content:
                pdf.set_font("Arial", "B", 14)
                pdf.ln(10)
                pdf.cell(0, 10, safe(f"Classification: {classification}"), ln=True)
                pdf.set_font("Arial", "", 12)
                pdf.ln(5)
                pdf.multi_cell(0, 10, safe(markdown_content[classification].strip()))
            else:
                pdf.cell(0, 10, "No description available for the predicted classification.", ln=True)

            # Save PDF in the same subfolder
            output_path = latest_subfolder / "clinical_report.pdf"
            pdf.output(str(output_path))

            return ProcessingResult(
                success=True,
                message=f"Report generated successfully at {output_path}",
                data={"output_path": output_path}
            )
        except Exception as e:
            logger.error(f"Error creating PDF report: {str(e)}")
            return ProcessingResult(
                success=False,
                message=f"Error creating PDF report: {str(e)}"
            )

# Example usage
if __name__ == "__main__":
    user_info = {"name": "John Doe", "id": "12345"}
    report_generator = ReportGenerator(user_info)

    # Read the markdown content from a .md file
    markdown_file_path = Path("cancer-classifications.md")  # Adjust this path accordingly
    if not markdown_file_path.is_file():
        print("Markdown file does not exist.")
        # exit() is an interactive helper installed by the `site` module and
        # is not guaranteed to exist; SystemExit always works and lets us
        # report a failure status.
        raise SystemExit(1)
    # Read as UTF-8 explicitly instead of relying on the locale encoding.
    markdown_content = markdown_file_path.read_text(encoding="utf-8")

    # Extract classification descriptions from markdown
    classifications = report_generator.extract_classification_descriptions(markdown_content)

    # Simulated predicted classification
    predicted_classification = "Adenocarcinoma (ACA)"  # Use the actual classification header from your file

    # Generate PDF report
    base_folder = Path("qa_outputs")  # This should be the base folder containing timestamped subfolders
    report_result = report_generator.create_pdf_report(base_folder, predicted_classification, classifications)

    print(report_result.message)