In [1]:
%pip install flask flask-cors torch torchvision

Note: you may need to restart the kernel to use updated packages.


In [2]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from GlaucomaModel import UNet, vertical_cup_to_disc_ratio, refine_seg

app = Flask(__name__)
# NOTE(review): every route accepts cross-origin requests from any origin;
# consider restricting "origins" to the known frontend host(s).
CORS(app, resources={r"/*": {"origins": "*"}})

# vCDR decision boundary: a predicted vCDR strictly greater than this value
# is labelled "Glaucoma" by the /api/calculate_vcdr endpoint.
vCDR_threshold = 0.6

# Run inference on the first CUDA device when one is present, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two-channel segmentation UNet for 3-channel (RGB) input.
model = UNet(n_channels=3, n_classes=2)
model.to(device)  # nn.Module.to moves parameters in place

# NOTE(review): torch.load unpickles arbitrary Python objects, and recent PyTorch
# versions emit a FutureWarning about `weights_only` here (see the cell output).
# If the checkpoint holds only tensors/primitive containers, weights_only=True is safer.
checkpoint = torch.load('unet_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode for dropout/batch-norm layers

# Preprocessing function for input images
def preprocess(image, size=256):
    """Convert a PIL image into a normalized tensor ready for the UNet.

    Steps: convert to RGB (whatever the source mode), resize so the SHORTER
    side equals ``size`` (aspect ratio preserved), then center-crop a
    ``size`` x ``size`` square. Pixels outside that square are discarded;
    no padding is applied. The result is converted to a float tensor and
    normalized per channel with the mean/std values below.

    Args:
        image: a PIL image in any mode.
        size: side length, in pixels, of the square output. The default of
            256 reproduces the previous hard-coded behaviour.

    Returns:
        A float tensor of shape (3, size, size).
    """
    image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(size),             # int -> shorter side becomes `size`
        transforms.CenterCrop((size, size)),  # keep only the central square
        transforms.ToTensor(),
        # These are the widely used ImageNet channel statistics; they must match
        # whatever normalization the checkpoint was trained with.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    return transform(image)

def _analyze_upload(file):
    """Run vCDR inference on one uploaded file and return its result entry.

    Args:
        file: one uploaded file object taken from ``request.files``.

    Returns:
        ``{"file": <filename>, "vCDR": "<ratio to 2 decimals>", "prediction": <label>}``.

    Raises:
        Any exception from decoding, preprocessing or inference, and
        ValueError when the computed vCDR is not a finite number. The caller
        turns these into a per-file error entry.
    """
    # Image.open is lazy and keeps the upload open; the context manager closes
    # it once convert() has produced an independent in-memory RGB copy.
    with Image.open(file) as raw_image:
        image = raw_image.convert('RGB')

    input_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        logits = model(input_tensor)

    # Channel 0 -> optic disc (OD), channel 1 -> optic cup (OC).
    # NOTE(review): the raw model output is thresholded at 0.5. If UNet.forward
    # returns unnormalized logits (no final sigmoid), this corresponds to a
    # probability of ~0.62, not 0.5 — confirm against GlaucomaModel.UNet.
    # The masks stay on the CPU: they are only consumed as numpy arrays below.
    pred_od = refine_seg((logits[:, 0, :, :] >= 0.5).type(torch.int8).cpu())
    pred_oc = refine_seg((logits[:, 1, :, :] >= 0.5).type(torch.int8).cpu())

    pred_vCDR = vertical_cup_to_disc_ratio(pred_od.cpu().numpy(), pred_oc.cpu().numpy())[0]

    # A NaN/inf ratio (presumably when no disc pixels survive segmentation —
    # depends on vertical_cup_to_disc_ratio) would compare False below and be
    # reported as "No Glaucoma"; surface it as an error instead.
    if not np.isfinite(pred_vCDR):
        raise ValueError("Could not compute vCDR: optic disc/cup segmentation failed")

    predicted_label = "Glaucoma" if pred_vCDR > vCDR_threshold else "No Glaucoma"

    return {
        "file": file.filename,
        "vCDR": f"{pred_vCDR:.2f}",
        "prediction": predicted_label,
    }


@app.route('/api/calculate_vcdr', methods=['POST'])
def calculate_vcdr():
    """Estimate the vertical cup-to-disc ratio (vCDR) for each uploaded image.

    Expects a multipart/form-data POST with one or more files under the
    ``images`` field.

    Returns:
        ``({"error": ...}, 400)`` when the ``images`` field is missing or empty.
        Otherwise ``({"results": [...]}, 200)``, with one entry per file, in
        upload order. Each entry is either ``{"file", "vCDR", "prediction"}``
        or ``{"file", "error"}``, so one bad upload does not fail the batch.
    """
    if 'images' not in request.files:
        return jsonify({"error": "No files uploaded"}), 400

    files = request.files.getlist('images')
    if not files:
        return jsonify({"error": "No files received"}), 400

    results = []
    for file in files:
        try:
            results.append(_analyze_upload(file))
        except Exception as e:
            # Deliberate best-effort per file: log the full traceback server-side
            # and report only this file as failed.
            app.logger.exception("Error processing file %s", file.filename)
            results.append({"file": file.filename, "error": str(e)})

    return jsonify({"results": results}), 200

if __name__ == '__main__':
    # Starts the blocking development server. In a notebook `__name__` is
    # '__main__', so this cell runs until interrupted.
    # NOTE(review): host '0.0.0.0' exposes the unauthenticated API on every
    # network interface (the log shows a LAN address); use '127.0.0.1' for local-only.
    app.run(port=5000, host='0.0.0.0')


  checkpoint = torch.load('unet_model.pth', map_location=device)


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.1.216:5000
Press CTRL+C to quit
127.0.0.1 - - [01/Dec/2024 13:06:17] "POST /api/calculate_vcdr HTTP/1.1" 200 -
