In [1]:
import os
import time
import base64
import json
import torch
import re
import numpy as np
import openai
from PIL import Image
from io import BytesIO
from datetime import datetime
from flask import Flask, request, render_template, flash, redirect
from werkzeug.utils import secure_filename
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
from neo4j import GraphDatabase
from langchain.graphs import Neo4jGraph
from langchain.chat_models import ChatOpenAI


In [None]:
# Flask config
app = Flask(__name__)
# Session-signing key (needed by flash()). Taken from the environment so it is
# never committed with the notebook; without it a random per-process key is
# used, which simply invalidates old session cookies after a restart.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(32).hex()

# API and DB credentials -- supplied through environment variables rather than
# hardcoded. Non-secret connection settings keep their previous values as defaults.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "")

# Setup models and services
graph = Neo4jGraph(url=NEO4J_URL, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
llm = ChatOpenAI(openai_api_key=openai.api_key, temperature=0, model_name="gpt-4.1")
neo4j_driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
# Run CLIP on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)


  graph = Neo4jGraph(url=NEO4J_URL, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)
  llm = ChatOpenAI(openai_api_key=openai.api_key, temperature=0, model_name="gpt-4.1")


In [3]:
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'tiff'}


def allowed_file(filename):
    """Tell whether `filename` carries one of the accepted image extensions.

    The extension is the text after the last dot, compared case-insensitively
    against ALLOWED_EXTENSIONS. Names without any dot are rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS


In [4]:
def generate_image_embedding_from_base64(base64_image):
    """Compute the CLIP image embedding for a base64-encoded image.

    Args:
        base64_image: Base64 text of an encoded image file.

    Returns:
        A 1-D numpy array: the flattened output of
        ``clip_model.get_image_features`` for the decoded RGB image.

    Raises:
        ValueError: if the payload cannot be decoded, opened as an image, or
            embedded. The original cause is chained. Failing loudly here
            matters: a substitute random vector would silently drive the
            knowledge-graph lookup (and the generated report) with noise.
    """
    try:
        image_bytes = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        # Inference only: no gradients are needed for feature extraction.
        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)
        return image_features.cpu().numpy().flatten()
    except Exception as e:
        print(f"Error processing base64 image: {e}")
        raise ValueError(f"Could not generate an image embedding: {e}") from e


In [5]:
def get_all_node_embeddings(tx):
    """Load the stored embedding of every OntologyTerm node.

    Runs inside the transaction `tx` and considers only nodes whose
    `node2Vec` and `id` properties are both set.

    Returns:
        dict mapping each node id to its embedding as a numpy array.
    """
    cypher = """
    MATCH (ot:OntologyTerm)
    WHERE ot.node2Vec IS NOT NULL AND ot.id IS NOT NULL
    RETURN ot.id AS id, ot.node2Vec AS embedding
    """
    # A later row with the same id overwrites an earlier one.
    return {row["id"]: np.array(row["embedding"]) for row in tx.run(cypher)}

def find_similar_nodes(target_embedding, all_embeddings, top_n=20):
    """Rank nodes by cosine similarity to `target_embedding`.

    Args:
        target_embedding: Query vector (array-like); flattened to one row.
        all_embeddings: Mapping of node id -> embedding vector. Every vector
            must have the same length as the query vector.
        top_n: Maximum number of results to return.

    Returns:
        List of ``(node_id, similarity)`` tuples, most similar first, at most
        `top_n` long. Ties keep the mapping's iteration order. An empty
        mapping yields an empty list.
    """
    if not all_embeddings:
        return []

    node_ids = list(all_embeddings)
    # Stack all candidates into one (n_nodes, dim) matrix so the similarity is
    # computed in a single call instead of one sklearn call per node.
    candidate_matrix = np.vstack(
        [np.asarray(all_embeddings[node_id]).reshape(1, -1) for node_id in node_ids]
    )
    query_row = np.asarray(target_embedding).reshape(1, -1)
    scores = cosine_similarity(query_row, candidate_matrix)[0]

    ranked = sorted(zip(node_ids, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]

def get_node_info_by_ids(tx, node_ids):
    """Look up which of `node_ids` exist as OntologyTerm nodes.

    Runs inside the transaction `tx`, passing `node_ids` as the `$node_ids`
    query parameter.

    Returns:
        List of the matching node ids, in the order the query yields them.
    """
    cypher = """
    MATCH (ot:OntologyTerm)
    WHERE ot.id IN $node_ids
    RETURN ot.id AS id
    """
    matched_ids = []
    for row in tx.run(cypher, node_ids=node_ids):
        matched_ids.append(row["id"])
    return matched_ids


In [6]:
def fetch_similar_terms_from_neo4j(image_embedding):
    """Return ids of the ontology terms most similar to `image_embedding`.

    Within one Neo4j session: read every stored node embedding, rank them
    against the image embedding with find_similar_nodes (default top_n), then
    pass the top ids to get_node_info_by_ids and return its result.
    """
    with neo4j_driver.session() as db_session:
        node_vectors = db_session.execute_read(get_all_node_embeddings)
        ranked_nodes = find_similar_nodes(image_embedding, node_vectors)

        # Only the ids are needed for the lookup; similarity scores are dropped.
        top_ids = []
        for node_id, _score in ranked_nodes:
            top_ids.append(node_id)

        matched_ids = db_session.execute_read(get_node_info_by_ids, top_ids)
    return matched_ids


In [7]:
def encode_image_to_base64(image_path):
    """Read the file at `image_path` and return its contents as base64 text.

    Returns:
        The base64 encoding as a UTF-8 string, or None if reading or encoding
        fails (the error is printed).
    """
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception as e:
        print(f"Encoding error: {e}")
        return None


In [8]:
# Folder holding the labelled example scans. Override with the
# FEW_SHOT_DATA_DIR environment variable; the default is the original location.
FEW_SHOT_DATA_DIR = os.environ.get("FEW_SHOT_DATA_DIR", "D:/project/Data")

# (image path relative to FEW_SHOT_DATA_DIR, label), in the order they are
# shown to the model.
FEW_SHOT_EXAMPLES = [
    ("Mild Dementia/OAS1_0028_MR1_mpr-1_100.jpg", "Mild Dementia"),
    ("Moderate Dementia/OAS1_0308_MR1_mpr-1_100.jpg", "Moderate Dementia"),
    ("Very Mild Dementia/OAS1_0003_MR1_mpr-1_100.jpg", "Very Mild Dementia"),
    ("Non Demented/OAS1_0112_MR1_mpr-2_149.jpg", "Non Demented"),
    ("Very Mild Dementia/OAS1_0003_MR1_mpr-1_113.jpg", "Very Mild Dementia"),
    ("Mild Dementia/OAS1_0028_MR1_mpr-1_112.jpg", "Mild Dementia"),
    ("Very Mild Dementia/OAS1_0003_MR1_mpr-1_116.jpg", "Very Mild Dementia"),
    ("Mild Dementia/OAS1_0028_MR1_mpr-1_123.jpg", "Mild Dementia"),
    ("Non Demented/OAS1_0112_MR1_mpr-3_107.jpg", "Non Demented"),
    ("Non Demented/OAS1_0112_MR1_mpr-3_111.jpg", "Non Demented"),
    ("Moderate Dementia/OAS1_0308_MR1_mpr-1_118.jpg", "Moderate Dementia"),
    ("Moderate Dementia/OAS1_0308_MR1_mpr-1_123.jpg", "Moderate Dementia"),
]


def _build_few_shot_messages(examples, data_dir):
    """Turn (relative_path, label) pairs into alternating user/assistant chat messages.

    Every example contributes a user message (instruction text plus the image
    as a base64 data URL) immediately followed by the assistant message with
    its label, so no example is left without an answer. Examples whose image
    cannot be read are skipped instead of being sent with a broken data URL.
    """
    messages = []
    for relative_path, label in examples:
        encoded = encode_image_to_base64(f"{data_dir}/{relative_path}")
        if encoded is None:
            # encode_image_to_base64 has already printed the error.
            continue
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": "Classify this MRI scan."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}},
            ],
        })
        messages.append({"role": "assistant", "content": label})
    return messages


few_shot_messages_base64 = _build_few_shot_messages(FEW_SHOT_EXAMPLES, FEW_SHOT_DATA_DIR)

In [9]:
# System prompt sent as the first chat message by index(). It tells the model
# to classify the scan into one of four dementia categories and describes the
# knowledge-graph schema (node types, properties, relationships) as context.
# NOTE(review): the schema text below spells the embedding property `node2vec`,
# while get_all_node_embeddings queries `ot.node2Vec` -- confirm which spelling
# the graph actually uses.
schema_prompt_with_embeddings = """
You are a medical AI assistant. Analyze the MRI image and consider its features.
Compare the image's features with concepts in the knowledge graph, using embeddings to find similar concepts.
Use the IDs of these similar concepts from the graph, along with the image and its features and few-shot examples,
to make an accurate diagnosis. Reply with the category: Moderate Dementia, Mild Dementia, Very Mild Dementia, Non Demented.
Also use other concepts in knowledge graph like Diagnosis stage etc. 
Here is the schema:
Node Types:
MRI_Image:
Properties: image_path, white_matter_lesions, cortical_thickness, ventricle_size, diagnosis
Relationships: [:HAS_DIAGNOSIS] -> DiagnosisStage
DiagnosisStage:
Properties: name, avg_white_matter_lesions, avg_cortical_thickness, avg_ventricle_size
OntologyTerm:
Properties: id, name, description, node2vec (embedding vector)
Relationships: [:RELATIONSHIP {type: "is_a" | "part_of"}] -> OntologyTerm
"""


In [10]:
# Used when the report has no (or an empty) TECHNIQUE section.
_DEFAULT_TECHNIQUE = (
    "Multiplanar MRI sequences were obtained including T1-weighted, "
    "T2-weighted, and FLAIR images."
)

# Fallback for reports without a DIAGNOSIS section: captures the text after
# "Primary diagnosis" up to a qualifier ("based on", "due to", "with", ",") or
# the end of that line. re.MULTILINE is required so `$` matches at the end of
# the line, not only at the end of a multi-line impression.
_PRIMARY_DIAGNOSIS_RE = re.compile(
    r"Primary diagnosis[:\-]?\s*(.+?)(?:\s*(?:based on|due to|with|,)|$)",
    re.IGNORECASE | re.MULTILINE,
)


def _extract_section(text, header, stop_headers=()):
    """Return the stripped text between `header` and the first stop header (or end of text); "" if absent."""
    stops = "|".join(re.escape(stop) for stop in stop_headers)
    terminator = rf"(?={stops}|\Z)" if stops else r"\Z"
    match = re.search(rf"{re.escape(header)}(.+?){terminator}", text, re.DOTALL)
    return match.group(1).strip() if match else ""


def _tidy_numbering(text):
    """Collapse the whitespace that follows list numbers such as "1." to a single space."""
    return re.sub(r"(\d+\.)\s+", r"\1 ", text)


def parse_professional_report(report_text):
    """Split a generated radiology report into its named sections.

    Section headers are matched case-sensitively (TECHNIQUE:, FINDINGS:,
    IMPRESSION:, DIAGNOSIS:, RECOMMENDATIONS:) after "**" markers are removed
    and blank lines are collapsed.

    Args:
        report_text: Full report text. A non-string value (e.g. None) is
            treated as an empty report.

    Returns:
        dict with the keys "technique", "findings", "impression", "diagnosis"
        and "recommendations":
        - "technique" falls back to a standard sentence when missing;
        - "findings" has "* " bullets replaced by "• ";
        - "diagnosis" comes from the DIAGNOSIS section, else from the
          "Primary diagnosis" line of the impression, else
          "Diagnosis not specified";
        - the remaining sections are "" when missing.
    """
    if not isinstance(report_text, str):
        report_text = ""

    # Clean up the report text
    text = report_text.replace("**", "").replace("\n\n", "\n").strip()

    impression = _tidy_numbering(
        _extract_section(text, "IMPRESSION:", ("DIAGNOSIS:", "RECOMMENDATIONS:"))
    )

    # Prefer the dedicated DIAGNOSIS section; if it is absent or empty, fall
    # back to the "Primary diagnosis" line of the impression.
    diagnosis = _extract_section(text, "DIAGNOSIS:", ("RECOMMENDATIONS:",))
    if not diagnosis:
        primary_match = _PRIMARY_DIAGNOSIS_RE.search(impression)
        if primary_match:
            candidate = primary_match.group(1).strip()
            candidate = re.sub(r"^\W+|\W+$", "", candidate)  # Remove surrounding punctuation
            diagnosis = candidate.split(",")[0]  # Take only first part if comma separated

    return {
        "technique": _extract_section(text, "TECHNIQUE:", ("FINDINGS:",)) or _DEFAULT_TECHNIQUE,
        "findings": _extract_section(text, "FINDINGS:", ("IMPRESSION:",)).replace("* ", "• "),
        "impression": impression,
        "diagnosis": diagnosis or "Diagnosis not specified",
        "recommendations": _tidy_numbering(_extract_section(text, "RECOMMENDATIONS:")),
    }

In [11]:
# Media type to declare in the data URL for each accepted upload extension.
# NOTE(review): confirm the vision endpoint accepts TIFF; if it does not,
# convert such uploads to PNG/JPEG before sending.
_MIME_BY_EXTENSION = {
    'png': 'image/png',
    'jpg': 'image/jpeg',
    'jpeg': 'image/jpeg',
    'tiff': 'image/tiff',
}


def _build_report_prompt(embedding_context):
    """Return the report-generation instructions, embedding the knowledge-graph context text."""
    return (
        f"Generate a professional radiology report for this MRI brain scan using the image features and "
        f"graph-based knowledge context provided below.\n\n"
        f"Knowledge Graph Context (based on image features and node embeddings): {embedding_context}.\n\n"
        "Follow this exact structure and clinical formatting:\n\n"
        "MRI BRAIN - DEMENTIA ASSESSMENT REPORT\n\n"
        "TECHNIQUE:\n"
        "Multiplanar MRI sequences were obtained including T1-weighted, T2-weighted, and FLAIR images.\n\n"
        "FINDINGS:\n"
        "- Brain parenchyma: Describe volume loss, asymmetry, or preservation\n"
        "- Ventricular system: Note size, symmetry, or any dilation\n"
        "- Sulci and gyri: Comment on cortical thinning or widening\n"
        "- White matter: Describe hyperintensities or demyelination\n"
        "- Basal ganglia and thalamus: Mention any signal abnormalities or structural changes\n"
        "- Posterior fossa: Describe cerebellum and brainstem appearance\n\n"
        "IMPRESSION:\n"
        "1. Primary diagnosis: [Exactly one of: Non Demented, Very Mild Dementia, Mild Dementia, or Moderate Dementia] explain based on key imaging and graph-derived findings\n"
        "2. Supporting evidence:Specific features observed in the scan and relevant graph-derived contex\n"
        "3. Differential considerations: Only list if graph features suggest ambiguity.Prioritize by biological plausibility. Include atypical presentations if indicated.Don't reference about graph  just use it and give diafferential consideration.\n\n"
        "DIAGNOSIS:\n"
        "[Repeat the exact diagnosis category from the Impression section here]\n\n"
        "RECOMMENDATIONS:\n"
        "Clinical correlation with neuropsychological testing\n"
        "Follow-up imaging or assessment timeline if necessary\n"
        "Add more reccomendations based on diagnosis atleast 4 to 6\n"
        "Consider referral to neurology or memory clinic if warranted\n\n"
        "Important Notes:\n"
        "- The diagnosis must appear in both IMPRESSION and DIAGNOSIS sections\n"
        "- Use only the exact diagnosis categories listed above\n"
        "- Keep language professional and concise\n"
        "- Do not mention the knowledge graph in the report"
    )


@app.route('/', methods=['GET', 'POST'])
def index():
    """Upload form (GET) and MRI report generation (POST).

    On POST the uploaded image is embedded with CLIP, the most similar
    ontology term ids are fetched from Neo4j, and GPT-4.1 is asked (with the
    few-shot examples) for a structured report, which is parsed and rendered
    with "result.html". Any validation or processing problem is flashed and
    the user is redirected back to the form.
    """
    if request.method != 'POST':
        return render_template('index.html')

    if 'file' not in request.files:
        flash("No file part")
        return redirect(request.url)

    file = request.files['file']
    if file.filename == '':
        flash("No selected file")
        return redirect(request.url)

    if not file or not allowed_file(file.filename):
        # Previously this case fell through silently and just re-rendered the form.
        flash("Unsupported file type. Allowed types: " + ", ".join(sorted(ALLOWED_EXTENSIONS)))
        return redirect(request.url)

    try:
        # Read the uploaded image
        image_bytes = file.read()
        base64_image = base64.b64encode(image_bytes).decode('utf-8')
        start_time = time.time()

        # Generate image embedding and fetch related terms from Neo4j
        image_embedding = generate_image_embedding_from_base64(base64_image)
        similar_terms_info = fetch_similar_terms_from_neo4j(image_embedding)
        if similar_terms_info:
            # str() guards against non-string ids coming back from the graph.
            embedding_context = "Related Ontology Term IDs: " + ", ".join(
                str(term_id) for term_id in similar_terms_info
            )
        else:
            embedding_context = "No related concepts found in the knowledge graph."

        # Declare the media type that matches the upload instead of always "image/jpeg".
        extension = file.filename.rsplit('.', 1)[1].lower()
        mime_type = _MIME_BY_EXTENSION.get(extension, 'image/jpeg')

        # Construct prompt for GPT with image and text
        test_image_prompt = {
            "role": "user",
            "content": [
                {"type": "text", "text": _build_report_prompt(embedding_context)},
                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}
            ]
        }

        # GPT-4.1 API call
        response = openai.ChatCompletion.create(
            model="gpt-4.1",
            messages=[{"role": "system", "content": schema_prompt_with_embeddings}] + few_shot_messages_base64 + [test_image_prompt],
            max_tokens=1000
        )

        # The content may be missing/None; treat that as an empty report.
        full_response = (response['choices'][0]['message']['content'] or "").strip()

        # Parse with the improved professional parser
        report_data = parse_professional_report(full_response)

        # Calculate processing time and get timestamp
        processing_time = round(time.time() - start_time, 2)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        return render_template(
            "result.html",
            report_data=report_data,
            image_data=base64_image,
            timestamp=timestamp,
            processing_time=processing_time
        )

    except Exception as e:
        flash(f"Error processing image: {e}")
        return redirect(request.url)


In [12]:
# Start the Flask server only when this file is executed directly. The
# reloader is disabled in the call below.
if __name__ == "__main__":
    app.run(
        port=5001,
        debug=False,
        use_reloader=False,
    )

 * Serving Flask app '__main__'


 * Debug mode: off


 * Running on http://127.0.0.1:5001
Press CTRL+C to quit
127.0.0.1 - - [23/May/2025 09:32:52] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [23/May/2025 09:32:52] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [23/May/2025 09:34:51] "POST / HTTP/1.1" 200 -
