# Brain Tumor Diagnosis System Using Vision Transformers: Flask Application

**Andy Achouche**

Capstone Project submitted to the Faculty of the  
Grand Canyon University  
In partial fulfillment of the requirements for the degree of  

**Master of Science in Data Science**

April 30, 2025

---

© 2025 Andy Achouche

In [None]:
"""
Medical Imaging Application
---------------------------
This application demonstrates a multi-step workflow for processing MRI images and generating a medical report.
Steps include:
  - Doctor Login
  - Patient Profile (Search/Add Patient)
  - Medical & Family History Assessment
  - Combined Genetic & Laboratory Testing (enter a result for each test performed; use "N/A" or leave blank for tests not performed)
  - Imaging Studies instructions
  - MRI Image Upload and Prediction (using a pre-trained Vision Transformer)
  - Viewing/Downloading a detailed Medical Report
The application includes an inactivity timer (10 minutes) and provides "Back" and "Skip" navigation where appropriate.
"""

import base64
import hmac
import html
import io
import math
import os
import secrets
import sqlite3
from datetime import timedelta

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from flask import Flask, request, redirect, url_for, session, Response, render_template_string, jsonify
from PIL import Image
from torchvision import transforms
from vit_pytorch import ViT


# ---------------------------- FLASK ------------------------
app = Flask(__name__)
# Session cookies are signed with this key; anyone who knows it can forge an
# "authenticated" session, so it must not be a constant committed to source.
# The random fallback is per-process: sessions will not survive a restart and
# are not shared between worker processes — set FLASK_SECRET_KEY in deployment.
app.secret_key = os.environ.get('FLASK_SECRET_KEY') or secrets.token_hex(32)
app.url_map.strict_slashes = False
app.permanent_session_lifetime = timedelta(minutes=10)  # session expires in 10 mins


#-------------------------------------------------------
# Paths reachable without logging in. Each entry matches itself exactly or as a
# directory prefix ("/static" covers "/static/app.css"), but NOT arbitrary string
# prefixes — otherwise a route like "/login_admin" or "/healthz" would skip auth.
ALLOWED_PATHS = {"/login", "/health", "/static"}  # add others if needed


def _is_public_path(path):
    """Return True if *path* equals an ALLOWED_PATHS entry or lies beneath it."""
    return any(path == prefix or path.startswith(prefix + "/") for prefix in ALLOWED_PATHS)


@app.before_request
def require_login():
    """Redirect unauthenticated requests to the login page.

    Runs before every request. Public paths (see ALLOWED_PATHS) are always let
    through; everything else requires ``session["authenticated"]`` to be truthy.
    Returning None lets Flask continue to the matched view.
    """
    if _is_public_path(request.path):
        return None
    if not session.get("authenticated"):
        return redirect(url_for("login"))
    return None


# ------------------- Database Setup -------------------
def get_db_connection(db_path='patients.db'):
    """Open a SQLite connection whose rows can be indexed by column name.

    Args:
        db_path: Database file to open. Defaults to the application's
            ``patients.db``; pass ``':memory:'`` for a throwaway database.

    Returns:
        A new ``sqlite3.Connection`` with ``row_factory = sqlite3.Row``.
        The caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn

def init_db():
    """Create the ``patients`` table if it does not already exist.

    The connection is closed even if the DDL statement fails; ``with conn``
    commits on success and rolls back on error.
    """
    conn = get_db_connection()
    try:
        with conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS patients (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    first_name TEXT NOT NULL,
                    last_name TEXT NOT NULL,
                    age INTEGER,
                    gender TEXT
                )
            ''')
    finally:
        conn.close()

init_db()

# ------------------- Helper for Formatting Responses -------------------
# Answers where "yes" is abnormal (red) and "no" is normal (green).
_YES_NO_CATEGORIES = frozenset({"Head Trauma", "Family History", "Seizures", "Genetic Disorders"})
# Test results where "positive" is abnormal (red) and "negative" is normal (green).
_POS_NEG_CATEGORIES = frozenset({"NF1/NF2", "TP53", "MLH1/MSH2", "VHL", "ctDNA"})
# Numeric lab tests: a value >= threshold is abnormal (red), below it normal (green).
# Thresholds match the "Normal <" hints on the combined-tests form.
_LAB_THRESHOLDS = {"NSE": 12.5, "S100": 0.105}


def format_response(category, value):
    """Return *value* as HTML, colour-coded red (abnormal) or green (normal).

    Callers render the result with ``|safe``, and several values (the NSE/S100
    fields) are free text typed by the user, so the value is always HTML-escaped
    before being returned or wrapped.

    Args:
        category: Field label, e.g. "Head Trauma", "TP53" or "NSE".
        value: Raw answer for that field, or None if it was never submitted.

    Returns:
        None when *value* is None. Otherwise the escaped value, wrapped in a red or
        green ``<span>`` when the category is known and the answer can be
        classified. Blank/"N/A" values, unrecognised answers, non-numeric or
        non-finite lab values, and unknown categories are returned escaped but
        uncoloured.
    """
    if value is None:
        return None
    text = html.escape(value)
    stripped = value.strip()
    if stripped == "" or stripped.upper() == "N/A":
        return text

    lowered = stripped.lower()
    if category in _YES_NO_CATEGORIES:
        abnormal = {"yes": True, "no": False}.get(lowered)
    elif category in _POS_NEG_CATEGORIES:
        abnormal = {"positive": True, "negative": False}.get(lowered)
    elif category in _LAB_THRESHOLDS:
        try:
            number = float(stripped)
        except ValueError:
            return text
        # "nan" would otherwise compare False and be shown green as "normal".
        if not math.isfinite(number):
            return text
        abnormal = number >= _LAB_THRESHOLDS[category]
    else:
        return text

    if abnormal is None:
        return text
    colour = "red" if abnormal else "green"
    return f"<span style='color: {colour};'>{text}</span>"

# ------------------- Common HTML Fragments -------------------
# Shared "Exit" button: a GET form to /logout, whose handler clears the session.
exit_button = '''
<form action="/logout" method="get" style="margin-top:20px; display:inline-block;">
    <input type="submit" value="Exit" style="font-size:18px; padding:10px;">
</form>
'''
# Shared "Help" button. toggleHelp() is defined in common_js, so a page must
# also render common_js for this button to do anything.
help_button = '''
<button onclick="toggleHelp()" style="font-size: 18px; padding: 10px; margin-top:20px;">Help</button>
'''
# Script and overlay markup appended to every page (rendered with |safe):
#   * toggleHelp() shows/hides #helpModal, filling it with the page's
#     `page_help_text` global when defined (pages set it in a <script> placed
#     before common_js), otherwise with defaultHelpText.
#   * A 600-second client-side inactivity countdown displayed in #timer; any
#     mouse move, key press or click restarts it, and when it runs out the
#     browser is sent to /logout.
# NOTE(review): the countdown is reset by client-side activity alone, whereas
# the server-side session (app.permanent_session_lifetime, 10 minutes) is
# presumably only refreshed by requests. A user busy on one form for more than
# 10 minutes without submitting may find the server session already expired.
# Confirm whether a keep-alive request is wanted.
common_js = '''
<script>
  var defaultHelpText = "This application assists doctors in managing patient data including adding new patients, searching for records, and processing medical tests. Follow on-screen instructions.";
  var helpText = (typeof page_help_text !== 'undefined') ? page_help_text : defaultHelpText;
  function toggleHelp() {
      var modal = document.getElementById('helpModal');
      if (modal.style.display === 'block') {
         modal.style.display = 'none';
      } else {
         modal.style.display = 'block';
         modal.innerHTML = helpText + '<br><br><button onclick="toggleHelp()" style="font-size: 18px; padding: 10px;">Close</button>';
      }
  }

  var countdown = 600; // 10 minutes (600 seconds)
  var timer;
  var countdownInterval;

  function updateTimerDisplay() {
      var timerElem = document.getElementById('timer');
      if (timerElem) { timerElem.innerText = countdown; }
  }

  function logoutUser() {
      window.location.href = '/logout';
  }

  function resetTimer() {
      countdown = 600;
      updateTimerDisplay();

      clearTimeout(timer);
      clearInterval(countdownInterval);

      timer = setTimeout(logoutUser, countdown * 1000);

      countdownInterval = setInterval(function(){
          countdown--;
          updateTimerDisplay();
          if (countdown <= 0) clearInterval(countdownInterval);
      }, 1000);
  }

  window.onload = resetTimer;
  document.onmousemove = resetTimer;
  document.onkeypress = resetTimer;
  document.onclick = resetTimer;
</script>


<div id="countdownTimer" style="font-size: 18px; position: fixed; top: 10px; right: 10px; background: #fff; padding: 5px; border: 1px solid #ccc; z-index: 10000;">
  Session expires in: <span id="timer">600</span> seconds
</div>
<div id="helpModal" style="display:none; position: fixed; top: 20%; left: 20%; width: 60%; background-color: #f8f8f8; padding: 20px; border: 1px solid #ccc; z-index: 1000;"></div>
'''

def get_patient_header():
    """Return an ``<h2>`` naming the current patient, or '' if none is selected.

    The name in ``session["patient"]`` comes from form input, and callers
    render this fragment with ``|safe``, so the name is HTML-escaped here.
    """
    patient = session.get("patient", "")
    if not patient:
        return ''
    return f'<h2 style="font-size: 20px;">Patient: {html.escape(patient)}</h2>'


# ----------------------------------------

@app.route("/health")
def health():
    """Liveness probe: plain-text ``ok`` with HTTP 200 (listed in ALLOWED_PATHS)."""
    body, status_code = "ok", 200
    return body, status_code

# ------------------- Doctor Login -------------------
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Render the doctor security-code form and authenticate submissions.

    GET renders the form. POST compares the submitted ``security_code`` with the
    expected code, read from the ``DOCTOR_SECURITY_CODE`` environment variable
    (default ``'1234'``, the previously hard-coded value). On success the session
    is marked authenticated and permanent, so ``app.permanent_session_lifetime``
    applies, and the user is redirected to the patient profile page. Otherwise
    the form is re-rendered with an error message.
    """
    login_help = "Enter your doctor security code to access the system."
    error_message = ""
    if request.method == 'POST':
        submitted = request.form.get('security_code') or ''
        expected = os.environ.get('DOCTOR_SECURITY_CODE', '1234')
        # Constant-time comparison so response timing does not reveal how much of
        # the code matched. Compare bytes, because compare_digest raises TypeError
        # for non-ASCII str. An empty configured code never authenticates.
        if expected and hmac.compare_digest(submitted.encode('utf-8'), expected.encode('utf-8')):
            session['authenticated'] = True
            session.permanent = True
            return redirect(url_for('patient_profile'))
        error_message = '<p style="color: red; font-size: 20px;">Incorrect code. Please try again.</p>'
    return render_template_string('''
        <div style="font-size: 20px;">
            <h1>Doctor Security Code</h1>
            {{ error|safe }}
            <form method="post">
                <label style="font-size: 18px;">Enter Security Code:</label>
                <input type="password" name="security_code" style="font-size: 18px; padding: 5px;">
                <input type="submit" value="Submit" style="font-size: 18px; padding: 10px;">
            </form>
            {{ help_button|safe }} {{ exit_button|safe }}
        </div>
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    ''', error=error_message, help_button=help_button, exit_button=exit_button,
         help_text=login_help, common_js=common_js)


# ------------------- LogOut ---------------------------
@app.route("/logout")
def logout():
    """End the doctor's session and send the browser back to the login page."""
    login_page = url_for("login")
    session.clear()  # drops the "authenticated" flag and all per-patient data
    return redirect(login_page)


# ------------------- Patient Profile -------------------
@app.route('/patient_profile')
def patient_profile():
    """Post-login landing page offering "Add New Patient" or "Search Existing Patient".

    Unauthenticated visitors are redirected to the login page.
    """
    if not session.get('authenticated'):
        return redirect(url_for('login'))
    page_vars = {
        'help_button': help_button,
        'exit_button': exit_button,
        'help_text': "Choose to add a new patient or search for an existing patient record.",
        'common_js': common_js,
    }
    page = '''
        <h1 style="font-size: 24px;">Patient Profile</h1>
        <p style="font-size: 20px;">Please choose an option:</p>
        <form action="/add_patient" method="get">
            <input type="submit" value="Add New Patient" style="font-size: 18px; padding: 10px;">
        </form>
        <br>
        <form action="/search_patient" method="get">
            <input type="submit" value="Search Existing Patient" style="font-size: 18px; padding: 10px;">
        </form>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    '''
    return render_template_string(page, **page_vars)

# ------------------- Add New Patient -------------------
@app.route('/add_patient', methods=['GET', 'POST'])
def add_patient():
    """Show the new-patient form and insert valid submissions into the database.

    On a valid POST the row is inserted, ``session['patient']`` is set to
    "First Last", and the user is redirected to the medical-history step.
    A blank first/last name or an age that is not a non-negative whole number
    re-renders the form with an error, instead of failing on the NOT NULL
    constraint or storing bad data.
    """
    if not session.get('authenticated'):
        return redirect(url_for('login'))
    add_help = "Enter the patient's details including first name, last name, age, and gender."
    error = ""
    if request.method == 'POST':
        first_name = (request.form.get('first_name') or '').strip()
        last_name = (request.form.get('last_name') or '').strip()
        age_text = (request.form.get('age') or '').strip()
        gender = request.form.get('gender')
        try:
            age = int(age_text)
        except ValueError:
            age = None
        if not first_name or not last_name:
            error = '<p style="color: red; font-size: 20px;">First and last name are required.</p>'
        elif age is None or age < 0:
            error = '<p style="color: red; font-size: 20px;">Age must be a whole number of 0 or more.</p>'
        else:
            conn = get_db_connection()
            try:
                with conn:  # commit on success, roll back on error
                    conn.execute(
                        'INSERT INTO patients (first_name, last_name, age, gender) VALUES (?, ?, ?, ?)',
                        (first_name, last_name, age, gender))
            finally:
                conn.close()
            # Only point the session at the patient once the record really exists.
            session['patient'] = f"{first_name} {last_name}"
            return redirect(url_for('medical_history'))
    return render_template_string('''
        <h1 style="font-size: 24px;">Add New Patient</h1>
        {{ error|safe }}
        <form method="post">
            <label style="font-size: 20px;">First Name:</label>
            <input type="text" name="first_name" required style="font-size: 18px; padding: 5px;"><br><br>
            <label style="font-size: 20px;">Last Name:</label>
            <input type="text" name="last_name" required style="font-size: 18px; padding: 5px;"><br><br>
            <label style="font-size: 20px;">Age:</label>
            <input type="number" name="age" required style="font-size: 18px; padding: 5px;"><br><br>
            <label style="font-size: 20px;">Gender:</label>
            <select name="gender" style="font-size: 18px; padding: 5px;">
                <option value="Male">Male</option>
                <option value="Female">Female</option>
            </select><br><br>
            <input type="submit" value="Submit" style="font-size: 18px; padding: 10px;">
        </form>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    ''', error=error, help_button=help_button, exit_button=exit_button,
         help_text=add_help, common_js=common_js)

# ------------------- Search Existing Patient -------------------
@app.route('/search_patient', methods=['GET', 'POST'])
def search_patient():
    """Let the doctor pick an existing patient from a dropdown.

    GET lists every stored patient as "First Last". A POST whose selection is
    blank, or does not match a stored patient, shows "Patient record not found".
    A valid selection is stored in ``session['patient']`` and the user continues
    to the medical-history step. With no stored patients, only an
    "Add New Patient" option is offered.
    """
    if not session.get('authenticated'):
        return redirect(url_for('login'))
    search_help = "Select a patient from the dropdown or add a new patient if no records exist."
    conn = get_db_connection()
    try:
        patients = conn.execute('SELECT * FROM patients').fetchall()
    finally:
        conn.close()
    names = [f'{row["first_name"]} {row["last_name"]}' for row in patients]
    # Names are user-entered and the options are rendered with |safe, so escape
    # them (quote=True also protects the value="..." attribute).
    options = "".join(
        f'<option value="{html.escape(name)}">{html.escape(name)}</option>' for name in names
    )
    error = ""
    if request.method == 'POST':
        selected = request.form.get('patient_select')
        # Reject values that are not one of the listed patients (e.g. a tampered form).
        if not selected or selected not in set(names):
            error = '<p style="color: red; font-size: 20px;">Patient record not found</p>'
        else:
            session['patient'] = selected
            return redirect(url_for('medical_history'))
    if not patients:
        return render_template_string('''
            <h1 style="font-size: 24px;">Search Existing Patient</h1>
            <p style="font-size: 20px;">No patient records found.</p>
            <form action="/add_patient" method="get">
                <input type="submit" value="Add New Patient" style="font-size: 18px; padding: 10px;">
            </form>
            {{ help_button|safe }} {{ exit_button|safe }}
            <script>var page_help_text = "{{ help_text }}";</script>
            {{ common_js|safe }}
        ''', help_button=help_button, exit_button=exit_button, help_text=search_help, common_js=common_js)
    return render_template_string('''
        <h1 style="font-size: 24px;">Search Existing Patient</h1>
        {{ error|safe }}
        <form method="post">
            <label style="font-size: 20px;">Select Patient:</label>
            <select name="patient_select" style="font-size: 20px; padding: 5px; height: 40px;">
                <option value="">--Select a Patient--</option>
                {{ options|safe }}
            </select><br><br>
            <input type="submit" value="Search" style="font-size: 18px; padding: 10px;">
        </form>
        <br>
        <form action="/add_patient" method="get">
            <input type="submit" value="Add New Patient" style="font-size: 18px; padding: 10px;">
        </form>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    ''', error=error, options=options, help_button=help_button, exit_button=exit_button,
         help_text=search_help, common_js=common_js)

# ------------------- Combined Questionnaire -------------------
@app.route('/medical_history', methods=['GET', 'POST'])
def medical_history():
    """Medical & family history questionnaire (four Yes/No/N/A questions).

    GET renders the form. POST stores the answers in ``session['medical_history']``,
    keyed by report label, and continues to the combined-tests step.
    Unauthenticated visitors are redirected to the login page.
    """
    if not session.get('authenticated'):
        return redirect(url_for('login'))
    history_help = "Answer the following questions by selecting Yes, No, or N/A if not applicable."
    if request.method == 'POST':
        # (report label, form field name) in the order the questions appear.
        question_fields = (
            ("Head Trauma", 'q1'),
            ("Family History", 'q2'),
            ("Seizures", 'q3'),
            ("Genetic Disorders", 'q4'),
        )
        session['medical_history'] = {label: request.form.get(field) for label, field in question_fields}
        return redirect(url_for('combined_tests'))
    form_page = '''
        <h1 style="font-size: 24px;">Medical & Family History Assessment</h1>
        {{ header|safe }}
        <p style="font-size: 20px;">Please answer the following questions:</p>
        <form method="post">
            <h2>Medical & Family History</h2>
            <label style="font-size: 20px;">Have you had any prior head trauma?</label>
            <select name="q1" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="Yes">Yes</option>
                <option value="No">No</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">Is there any family history of brain tumors?</label>
            <select name="q2" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="Yes">Yes</option>
                <option value="No">No</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">Have you experienced unexplained seizures or neurological symptoms?</label>
            <select name="q3" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="Yes">Yes</option>
                <option value="No">No</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">Do you have a history of genetic disorders related to tumors?</label>
            <select name="q4" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="Yes">Yes</option>
                <option value="No">No</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <input type="submit" value="Next Step" style="font-size: 18px; padding: 10px;">
        </form>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    '''
    return render_template_string(form_page, header=get_patient_header(), help_button=help_button,
                                  exit_button=exit_button, help_text=history_help, common_js=common_js)

# ------------------- Combined Genetic & Laboratory Tests -------------------
@app.route('/combined_tests', methods=['GET', 'POST'])
def combined_tests():
    """Genetic and laboratory test entry form.

    GET renders the form. POST stores every field (blank, "N/A" or a value) in
    ``session['combined_tests']`` under its report label and redirects to the
    colour-coded results page. "Back" returns to the medical-history step, the
    page that leads here; "Skip" jumps to the imaging-studies step.
    """
    header = get_patient_header()
    tests_help = ("Please enter the results for the following tests. For each test, select or enter a value. "
                  "If not performed, you may leave blank or enter 'N/A'.")
    if request.method == 'POST':
        session['combined_tests'] = {
            # Genetic testing:
            "NF1/NF2": request.form.get('nf_result'),
            "TP53": request.form.get('tp53_result'),
            "MLH1/MSH2": request.form.get('mlh1_msh2_result'),
            "VHL": request.form.get('vhl_result'),
            # Laboratory tests:
            "NSE": request.form.get('nse_result'),
            "S100": request.form.get('s100_result'),
            "ctDNA": request.form.get('ctdna_result')
        }
        return redirect(url_for('combined_results'))
    return render_template_string('''
        <h1 style="font-size: 24px;">Combined Genetic & Laboratory Tests</h1>
        {{ header|safe }}
        <form method="post">
            <h2>Genetic Testing Results</h2>
            <label style="font-size: 20px;">NF1/NF2 Gene Testing:</label>
            <select name="nf_result" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="negative">Negative</option>
                <option value="positive">Positive</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">TP53 Gene Mutation:</label>
            <select name="tp53_result" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="negative">Negative</option>
                <option value="positive">Positive</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">MLH1/MSH2 Gene Testing:</label>
            <select name="mlh1_msh2_result" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="negative">Negative</option>
                <option value="positive">Positive</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <label style="font-size: 20px;">VHL Gene Testing:</label>
            <select name="vhl_result" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="negative">Negative</option>
                <option value="positive">Positive</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <h2>Laboratory Test Results</h2>
            <label style="font-size: 20px;">NSE (Normal &lt; 12.5 ng/mL):</label>
            <input type="text" name="nse_result" value="N/A" style="font-size: 18px; padding: 5px;"><br><br>
            <label style="font-size: 20px;">S100 (Normal &lt; 0.105 µg/L):</label>
            <input type="text" name="s100_result" value="N/A" style="font-size: 18px; padding: 5px;"><br><br>
            <label style="font-size: 20px;">ctDNA:</label>
            <select name="ctdna_result" style="font-size: 20px;">
                <option value="">--Select--</option>
                <option value="negative">Negative</option>
                <option value="positive">Positive</option>
                <option value="N/A">N/A</option>
            </select><br><br>
            <input type="submit" value="Next Step" style="font-size: 18px; padding: 10px;">
        </form>
        <div style="margin-top:10px;">
          <form action="/medical_history" method="get" style="display:inline-block; margin-right:10px;">
              <input type="submit" value="Back" style="font-size: 18px; padding: 10px 20px;">
          </form>
          <form action="/imaging_studies" method="get" style="display:inline-block;">
              <input type="submit" value="Skip" style="font-size: 18px; padding: 10px 20px;">
          </form>
        </div>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    ''', header=header, help_button=help_button, exit_button=exit_button,
         help_text=tests_help, common_js=common_js)

# ------------------- Combined Tests Results (Display with Coloring) -------------------
@app.route('/combined_results')
def combined_results():
    """Show stored history and test answers, colour-coded via format_response().

    Reads ``session['medical_history']`` and ``session['combined_tests']``. Each
    section (history, genetic, laboratory) is emitted only when its data is
    present. The page links on to the imaging-studies step.
    """
    def render_section(title, keys, answers):
        # One <h2> heading followed by a "<p>label: formatted value</p>" per key.
        rows = "".join(f"<p>{k}: {format_response(k, answers.get(k, ''))}</p>" for k in keys)
        return f"<h2>{title}</h2>" + rows

    header = get_patient_header()
    tests = session.get('combined_tests', {})
    history = session.get('medical_history', {})

    sections = []
    if history:
        sections.append(render_section("Medical & Family History", list(history), history))
    if tests:
        test_groups = (
            ("Genetic Testing", ["NF1/NF2", "TP53", "MLH1/MSH2", "VHL"]),
            ("Laboratory Test Results", ["NSE", "S100", "ctDNA"]),
        )
        for title, keys in test_groups:
            if any(k in tests for k in keys):
                sections.append(render_section(title, keys, tests))
    combined_html = "".join(sections)

    next_button = '''
    <form action="/imaging_studies" method="get" style="display:inline-block;">
        <input type="submit" value="Continue to Imaging Studies" style="font-size: 18px; padding: 10px 20px;">
    </form>
    '''

    page = '''
        <h1 style="font-size: 24px;">Combined Test Results</h1>
        {{ header|safe }}
        <div style="font-size: 20px;">
            {{ combined_html|safe }}
        </div>
        <br>
        {{ next_button|safe }}
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "Review your combined test results. Abnormal responses are shown in red and normal responses in green. Blank or N/A values are unformatted."; </script>
        {{ common_js|safe }}
    '''
    return render_template_string(page, header=header, combined_html=combined_html,
                                  next_button=next_button, help_button=help_button,
                                  exit_button=exit_button, common_js=common_js)

# ------------------- Imaging Studies -------------------
@app.route('/imaging_studies', methods=['GET', 'POST'])
def imaging_studies():
    """MRI instructions page with a link on to the MRI upload form.

    GET shows the full recommendation text. Any other method (POST) shows a
    shorter "proceed to MRI upload" page. Both use the same template variables.
    """
    header = get_patient_header()
    imaging_help = "Follow the instructions for MRI scanning. Once you have the MRI results, proceed to the upload page."
    if request.method != 'GET':
        page = '''
            <h1 style="font-size: 24px;">Imaging Studies</h1>
            {{ header|safe }}
            <p style="font-size: 20px;">Please proceed to MRI upload.</p>
            <form action="/upload_mri" method="get">
                <input type="submit" value="Proceed to MRI Upload" style="font-size: 18px; padding: 10px;">
            </form>
            {{ help_button|safe }} {{ exit_button|safe }}
            <script>var page_help_text = "{{ help_text }}";</script>
            {{ common_js|safe }}
        '''
    else:
        page = '''
            <h1 style="font-size: 24px;">Imaging Studies</h1>
            {{ header|safe }}
            <p style="font-size: 20px;">Based on previous results, an MRI of the brain is recommended.</p>
            <p style="font-size: 20px;">Please proceed with MRI scanning. Once you have the MRI results, you can continue to the next step.</p>
            <form action="/upload_mri" method="get">
                <input type="submit" value="Proceed to MRI Upload" style="font-size: 18px; padding: 10px;">
            </form>
            {{ help_button|safe }} {{ exit_button|safe }}
            <script>var page_help_text = "{{ help_text }}";</script>
            {{ common_js|safe }}
        '''
    return render_template_string(page, header=header, help_button=help_button,
                                  exit_button=exit_button, help_text=imaging_help,
                                  common_js=common_js)

# ------------------- MRI Image Upload & Prediction Result -------------------
# Jinja template for the /predict result page. Expected variables:
# patient_header, graph (base64 PNG of the class histogram), predicted_class,
# predicted_probability, help_button and common_js.
# "Exit" goes to /logout, like the shared exit_button, so the session is
# actually cleared. /login merely showed the login form and left the user
# authenticated.
result_template = '''
    <h1 style="font-size: 24px;">MRI Classification Result</h1>
    {{ patient_header|safe }}
    <h2 style="font-size: 22px;">Prediction (Histogram):</h2>
    <div style="text-align: left;">
        <img src="data:image/png;base64,{{ graph }}" alt="MRI Histogram" style="max-width:500px; display: block;">
    </div>
    <h3 style="font-size: 22px;">Predicted Diagnosis: {{ predicted_class }}</h3>
    <h3 style="font-size: 22px;">Confidence: {{ predicted_probability }}%</h3>
    <p style="font-size: 18px;">
        Note: This prediction uses a pre-trained Vision Transformer (ViT) model—a neural network architecture that leverages self-attention mechanisms to capture global image features. Vision Transformers have demonstrated high precision in image analysis and classification.
    </p>
    <form action="/upload_mri" method="get" style="display:inline-block;">
        <input type="submit" value="Upload Another MRI" style="font-size: 20px; padding: 10px 20px;">
    </form>
    <form action="/patient_profile" method="get" style="display:inline-block; margin-left:10px;">
        <input type="submit" value="Start Over (New Patient)" style="font-size: 20px; padding: 10px 20px;">
    </form>
    <form action="/logout" method="get" style="display:inline-block; margin-left:10px;">
        <input type="submit" value="Exit" style="font-size: 20px; padding: 10px 20px;">
    </form>
    <br><br>
    <form action="/generate_report" method="get" style="display:inline-block;">
        <input type="submit" value="Generate Medical Report" style="font-size: 20px; padding: 10px 20px;">
    </form>
    {{ help_button|safe }}
    <script>var page_help_text = "This page shows the MRI classification results as a histogram. Click the button below to generate a comprehensive medical report."; </script>
    {{ common_js|safe }}
'''

# ------------------- /upload_mri -------------------
@app.route('/upload_mri')
def upload_mri():
    """Render the MRI upload form, which POSTs the chosen file to /predict.

    The file input is ``required`` (and limited to images in the browser's file
    picker), so an empty submission is stopped client-side instead of producing
    /predict's bare JSON error.
    """
    header = get_patient_header()
    upload_help = "Upload an MRI image (PNG format). After uploading, the system will analyze and classify the image."
    return render_template_string('''
        <h1 style="font-size: 24px;">Upload MRI Image</h1>
        {{ header|safe }}
        <form method="POST" action="/predict" enctype="multipart/form-data">
            <input type="file" name="file" accept="image/*" required style="font-size: 20px; padding: 10px 20px;">
            <input type="submit" value="Analyze and Classify" style="font-size: 20px; padding: 10px 20px;">
            <input type="reset" value="Reset" style="font-size: 20px; padding: 10px 20px;">
        </form>
        {{ help_button|safe }} {{ exit_button|safe }}
        <script>var page_help_text = "{{ help_text }}";</script>
        {{ common_js|safe }}
    ''', header=header, help_button=help_button, exit_button=exit_button, help_text=upload_help, common_js=common_js)

# ------------------- Model Loading and Prediction -------------------
# Vision Transformer classifier. These hyper-parameters must match the
# checkpoint: load_state_dict() (strict by default) fails on any key/shape
# mismatch. num_classes=4 corresponds to the four labels used in predict().
model = ViT(
    image_size=224,
    patch_size=16,
    num_classes=4,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded
# (parameter available since torch 1.13; it is the default from torch 2.6).
model.load_state_dict(torch.load('vit_mri_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()  # inference mode: disables the dropout layers configured above

# Preprocessing for uploaded images: resize to the ViT input size, convert to a
# [0, 1] tensor, then map to roughly [-1, 1]. The single mean/std value is
# broadcast across all RGB channels.
# NOTE(review): confirm this matches the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# ------------------- /predict -------------------
def _histogram_png_base64(labels, percentages):
    """Draw a bar chart of per-class percentages and return it as base64-encoded PNG text."""
    # Use a standalone Figure rather than pyplot: plt.savefig() saves pyplot's
    # global "current figure", which concurrent requests in a threaded server can
    # swap out. A standalone Figure also needs no plt.close() to avoid leaking.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    bars = ax.bar(labels, percentages, color='skyblue')
    ax.set_ylabel('Percentage')
    ax.set_title('MRI Classification Results')
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f'{pct}%',
                ha='center', va='bottom', fontsize=10)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    return base64.b64encode(buf.getvalue()).decode()


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded MRI image with the ViT model and render the result page.

    Expects a multipart upload in field ``file``. Responds 400 with a JSON
    ``error`` when no file was sent or it cannot be decoded as an image.
    On success it does two things:
      * stores per-class percentages (``class_0``..``class_3``) plus
        ``predicted_class`` and ``predicted_probability`` in
        ``session['mri_results']`` (later read by /generate_report);
      * renders ``result_template`` with a bar chart of the percentages.
    """
    upload = request.files.get('file')
    if upload is None or upload.filename == '':
        return jsonify({'error': 'No file part or no selected file'}), 400
    try:
        img = Image.open(upload).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError):
        # OSError covers PIL's UnidentifiedImageError and truncated image data.
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400

    # List index == model output index.
    class_labels = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary Tumor']
    with torch.no_grad():
        probabilities = F.softmax(model(transform(img).unsqueeze(0)), dim=1)[0]
    predicted_index = int(torch.argmax(probabilities))
    class_probs = [round(p * 100, 2) for p in probabilities.tolist()]

    results = {f'class_{i}': pct for i, pct in enumerate(class_probs)}
    results['predicted_class'] = class_labels[predicted_index]
    results['predicted_probability'] = round(probabilities[predicted_index].item() * 100, 2)
    session['mri_results'] = results

    return render_template_string(result_template,
                                  patient_header=get_patient_header(),
                                  graph=_histogram_png_base64(class_labels, class_probs),
                                  common_js=common_js, help_button=help_button,
                                  exit_button=exit_button,
                                  **results)

# ------------------- Report Generation (HTML Report) -------------------
@app.route('/generate_report')
def generate_report():
    """Build the patient's medical report and return it as a downloadable HTML file.

    Reads the workflow state stored in the Flask session:
      * ``patient``         -- patient description (defaults to "Unknown Patient").
      * ``medical_history`` -- question -> answer mapping; section omitted when empty.
      * ``combined_tests``  -- test name -> entered value, rendered through
                               ``format_response``; section omitted when empty.
      * ``mri_results``     -- ``class_0``..``class_3`` percentages plus
                               ``predicted_class`` / ``predicted_probability``, as
                               stored by the MRI prediction step. When absent, the
                               report states that no MRI analysis was performed
                               instead of charting placeholder values.

    Every value interpolated into the HTML is escaped, because session contents
    originate from form input and the report is opened in a browser.

    Returns:
        A ``text/html`` Response with ``Content-Disposition: attachment`` so the
        browser downloads it as ``medical_report.html``.
    """
    from html import escape  # local import keeps this block self-contained

    def esc(value):
        # str() first: session values may be non-strings (numbers, lists, dicts).
        return escape(str(value))

    patient_info = session.get('patient', 'Unknown Patient')
    medical_history = session.get('medical_history', {})
    combined_tests = session.get('combined_tests', {})
    mri_results = session.get('mri_results')

    report_sections = [
        f"<div class='section'><h2>Patient Information</h2><p>{esc(patient_info)}</p></div>"
    ]

    # Only include optional sections if data was entered.
    if medical_history:
        history_report = "".join(
            f"{esc(question)}: {esc(answer)}<br>" for question, answer in medical_history.items()
        )
        report_sections.append(f"<div class='section'><h2>Medical History</h2><p>{history_report}</p></div>")

    if combined_tests:
        # NOTE(review): format_response output is escaped as plain text; if that
        # helper is ever meant to return HTML markup, revisit this.
        combined_report = "".join(
            f"{esc(key)}: {esc(format_response(key, val))}<br>" for key, val in combined_tests.items()
        )
        report_sections.append(f"<div class='section'><h2>Combined Test Results</h2><p>{combined_report}</p></div>")

    if mri_results:
        # Ordered to match the class indices used by the prediction step (0..3).
        tumor_explanations = {
            "Glioma": "Gliomas are tumors that originate from glial cells. They tend to be aggressive and may require comprehensive treatment including surgery, radiotherapy, and chemotherapy.",
            "Meningioma": "Meningiomas typically arise from the meninges. They are often benign but may cause symptoms due to their location, with treatment often involving surgical removal.",
            "No Tumor": "No tumor was detected; this is a favorable result, though routine monitoring is advised if symptoms persist.",
            "Pituitary Tumor": "Pituitary tumors occur in the pituitary gland and can affect hormone levels; further evaluation is recommended."
        }
        classes = list(tumor_explanations)
        predicted_class = str(mri_results.get('predicted_class', ''))
        # First label (in class order) contained in the predicted class string.
        tumor_type = next((name for name in classes if name in predicted_class), "Unknown")
        explanation = tumor_explanations.get(tumor_type, "No explanation available.")

        percentages = [mri_results.get(f'class_{i}', 0) for i in range(len(classes))]
        graph_base64 = _mri_chart_base64(classes, percentages)
        mri_section = f"""
    <div class='section graph'>
      <h2>MRI Classification Results</h2>
      <img src="data:image/png;base64,{graph_base64}" alt="MRI Histogram" style="max-width:500px; display: block;">
      <p>
        Predicted Diagnosis: {esc(mri_results.get('predicted_class'))}<br>
        Confidence: {esc(mri_results.get('predicted_probability'))}%
      </p>
    </div>
    """
        report_sections.append(mri_section)
        report_sections.append(f"<div class='section'><h2>Tumor Explanation for {tumor_type}</h2><p>{explanation}</p></div>")
        report_sections.append("<div class='section'><p><i>Note: This prediction uses a pre-trained Vision Transformer (ViT) model—a neural network architecture that leverages self-attention mechanisms to capture global image features. Vision Transformers have demonstrated high precision in image analysis and classification.</i></p></div>")
    else:
        # No prediction in the session: say so rather than charting placeholder values.
        report_sections.append("<div class='section'><h2>MRI Classification Results</h2><p>No MRI analysis was performed.</p></div>")

    report_html = f"""
<html>
<head>
  <meta charset="utf-8">
  <title>Medical Report</title>
  <style>
    body {{ font-family: Arial, sans-serif; margin: 20px; }}
    h1 {{ text-align: left; }}
    .section {{ margin-bottom: 20px; }}
    .graph {{ text-align: left; }}
  </style>
</head>
<body>
  <h1>Medical Report</h1>
  {''.join(report_sections)}
</body>
</html>
    """
    return Response(report_html, mimetype="text/html", headers={"Content-Disposition": "attachment;filename=medical_report.html"})


def _mri_chart_base64(labels, percentages):
    """Render a labelled bar chart of per-class percentages; return it as a base64 PNG string."""
    fig, ax = plt.subplots()
    try:
        bars = ax.bar(labels, percentages, color='skyblue')
        ax.set_ylabel('Percentage')
        ax.set_title('MRI Classification Results')
        for bar, perc in zip(bars, percentages):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f'{perc}%',
                    ha='center', va='bottom', fontsize=10)
        buf = io.BytesIO()
        # fig.savefig rather than plt.savefig: pyplot's "current figure" is
        # process-global, so under a threaded server it may belong to another request.
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)  # release the figure even if rendering raised
    return base64.b64encode(buf.getvalue()).decode()


if __name__ == '__main__':
    # Script entry point: serve on every network interface at port 5000
    # with the interactive debugger disabled.
    app.run(debug=False, port=5000, host='0.0.0.0')
