In [2]:
import edgedb
import gradio as gr
import gradio as gr
import edgedb
from transformers import BertTokenizer, BertModel
import torch
import numpy as np

# Module-level EdgeDB client; get_injury_data() issues its query through it.
# create_client() is called without arguments, so connection settings are
# presumably resolved from the environment / project config -- TODO confirm.
conn = edgedb.create_client()

In [7]:
# Load the pretrained BERT tokenizer and encoder once at cell level so that
# get_symptom_embedding() can reuse them on every call. Both are loaded from
# the same checkpoint name so they always match.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)

def get_injury_data(location):
    """Fetch the injury catalogue for one body location from EdgeDB.

    Args:
        location: Name of the ``InjuryLocation`` to look up. It is bound to
            the ``$injury_name`` query parameter (despite the parameter's
            name, it is compared against ``InjuryLocation.name``).

    Returns:
        Whatever ``conn.query`` returns for the query below: the matching
        ``InjuryLocation`` objects, each carrying its injuries together with
        their symptoms, treatments and the devices each treatment requires.
    """
    # Nested shape: location -> injuries -> (symptoms, treatments -> devices).
    injury_query = """
        SELECT InjuryLocation {
            name,
            injuries: {
                name,
                symptoms: {
                    description
                },
                treatments: {
                    name,
                    requires_devices: {
                        name
                    }
                }
            }
        }
        FILTER .name = <str>$injury_name
        """
    return conn.query(injury_query, injury_name=location)

def get_symptom_embedding(symptom):
    """Embed a symptom description as the mean of BERT's last hidden states.

    Args:
        symptom: Free-text symptom description.

    Returns:
        A NumPy array holding the last hidden state averaged over the token
        axis (dim 1); for a single string this is one row per input, i.e.
        shape ``(1, hidden_size)``.
    """
    # truncation=True clips overly long text to the tokenizer's model maximum
    # instead of letting the forward pass fail on an out-of-range position.
    inputs = tokenizer(symptom, return_tensors='pt', truncation=True)
    # Pure inference: disable autograd so no gradient graph is built or kept.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).detach().numpy()

def find_best_match(location, symptoms):
    """Predict the most likely injury at ``location`` from free-text symptoms.

    Each user symptom is embedded and matched to the nearest (Euclidean
    distance) symptom description stored for the location; the injury owning
    that description receives one vote. The injury with the most votes wins,
    ties going to the injury that was voted for first.

    Args:
        location: Injury location name passed to ``get_injury_data``.
        symptoms: Comma-separated symptom descriptions. Blank entries (for
            example from a trailing comma) are ignored.

    Returns:
        A tuple ``(injury_name, treatment_names, device_names)`` where
        ``treatment_names`` lists the winning injury's treatments and
        ``device_names`` lists the distinct devices those treatments require,
        in first-seen order.

    Raises:
        ValueError: If the location has no data, no non-blank symptom was
            given, or the location has no stored symptoms to compare against.
    """
    injury_data = get_injury_data(location)
    if not injury_data:
        raise ValueError(f"No injury data found for location {location!r}")

    # Drop blank entries so they do not cast meaningless votes.
    symptom_list = [s.strip() for s in (symptoms or "").split(',') if s.strip()]
    if not symptom_list:
        raise ValueError("At least one symptom must be provided")

    injuries = list(injury_data[0].injuries)

    # Embed every stored symptom exactly once; these do not depend on the
    # user's input, so recomputing them per user symptom is wasted work.
    candidates = [
        (injury_index, get_symptom_embedding(db_symptom.description))
        for injury_index, injury in enumerate(injuries)
        for db_symptom in injury.symptoms
    ]
    if not candidates:
        raise ValueError(f"No symptoms are recorded for location {location!r}")

    # Vote by injury index (not by object) so counting never depends on how
    # the database objects hash or compare.
    votes = {}
    for user_symptom in symptom_list:
        user_embedding = get_symptom_embedding(user_symptom)
        # min() keeps the first candidate on equal distance, matching a
        # strict "<" comparison over the candidates in order.
        nearest_index, _ = min(
            candidates,
            key=lambda candidate: np.linalg.norm(user_embedding - candidate[1]),
        )
        votes[nearest_index] = votes.get(nearest_index, 0) + 1

    # Dicts preserve insertion order, so ties resolve to the first-voted injury.
    best_injury = injuries[max(votes, key=votes.get)]
    best_treatments = list(best_injury.treatments)

    treatment_names = [treatment.name for treatment in best_treatments]
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    medical_devices = list(dict.fromkeys(
        device.name
        for treatment in best_treatments
        for device in treatment.requires_devices
    ))

    return best_injury.name, treatment_names, medical_devices

def process_injury(location, symptoms):
    """Gradio callback: predict an injury and format the result for display.

    Args:
        location: Injury location passed through to ``find_best_match``.
        symptoms: Comma-separated symptom text passed through unchanged.

    Returns:
        A 3-tuple of strings: the predicted injury name, the treatment names
        joined with ``', '``, and the device names joined with ``', '``.
    """
    injury_name, treatment_names, device_names = find_best_match(location, symptoms)
    treatments_text = ', '.join(treatment_names)
    devices_text = ', '.join(device_names)
    return injury_name, treatments_text, devices_text

# Input widgets: the body location to search and the free-text symptom list.
dropdown = gr.Dropdown(label="Injury Location", choices=["Shoulder", "Knee", "Feet"])
textbox = gr.Textbox(label="Symptoms", placeholder="Enter symptoms separated by commas", lines=2)

# One read-only text box per value returned by process_injury, in order.
interface = gr.Interface(
    title="Injury Predictor",
    fn=process_injury,
    inputs=[dropdown, textbox],
    outputs=[
        gr.Textbox(label=output_label)
        for output_label in ("Predicted Injury", "Treatment", "Medical Devices")
    ],
)

# share=True additionally requests a temporary public gradio.live URL.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7861
IMPORTANT: You are using gradio version 4.20.1, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://95a36a7ba32b8dffcc.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


