In [2]:
import os

import edgedb
import gradio as gr
import numpy as np
import torch
from transformers import BertTokenizer, BertModel



In [3]:
# Connect to the EdgeDB Cloud instance.
#
# The secret key is read from the environment so the credential never lives in
# the notebook or its version history. Export EDGEDB_SECRET_KEY before starting
# the kernel.
# SECURITY: a key was previously hardcoded in this cell; treat it as leaked and
# revoke it in the EdgeDB Cloud console.
key = os.environ.get('EDGEDB_SECRET_KEY')
if not key:
    raise RuntimeError(
        "EDGEDB_SECRET_KEY is not set. Export the EdgeDB Cloud secret key in "
        "the environment before running this notebook."
    )

# NOTE(review): 'insecure' turns off TLS certificate verification, which exposes
# the connection to man-in-the-middle attacks. It was added as a workaround for
# a certificate problem; remove it once the certificates are sorted out.
os.environ['EDGEDB_CLIENT_TLS_SECURITY'] = 'insecure'
os.environ['EDGEDB_INSTANCE'] = '459Project/mydb'

# create_client() picks up the EDGEDB_* variables set above.
conn = edgedb.create_client()

In [5]:

# The tokenizer and the encoder must come from the same pretrained checkpoint,
# so the checkpoint name is spelled out once and shared by both loaders.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)

def get_injury_data(location):
    """Fetch the injuries recorded for one injury location.

    Runs an EdgeQL query over `InjuryLocation` objects whose `name` equals
    `location`, pulling each location's injuries together with their symptom
    descriptions, treatment names and the names of the devices each treatment
    requires.

    Args:
        location: Location name bound to the query's `$injury_name` parameter.

    Returns:
        Whatever `conn.query` returns for the query.
    """
    edgeql = """
        SELECT InjuryLocation {
            name,
            injuries: {
                name,
                symptoms: {
                    description
                },
                treatments: {
                    name,
                    requires_devices: {
                        name
                    }
                }
            }
        }
        FILTER .name = <str>$injury_name
        """
    return conn.query(edgeql, injury_name=location)

def get_symptom_embedding(symptom):
    """Embed a symptom description as a mean-pooled BERT vector.

    Args:
        symptom: Free-text symptom description.

    Returns:
        numpy array obtained by averaging the model's `last_hidden_state`
        over dim 1 (presumably shape (1, hidden_size) for a single string —
        TODO confirm).
    """
    # truncation=True clips over-long text to the tokenizer's maximum length
    # instead of letting the model fail on an out-of-range position.
    inputs = tokenizer(symptom, return_tensors='pt', truncation=True)
    # Inference only: skip autograd bookkeeping. The result then needs no
    # .detach() before conversion to numpy.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy()

def find_best_match(location, symptoms, *, fetch_injuries=None, embed=None):
    """Predict the most likely injury at `location` from free-text symptoms.

    Every user symptom votes for the injury that owns the stored symptom whose
    embedding is nearest to it (smallest `np.linalg.norm` of the difference).
    The injury with the most votes wins; ties go to the injury that appears
    first in the query result, so the outcome is deterministic.

    Args:
        location: Injury location name used to look up candidate injuries.
        symptoms: Comma-separated symptom descriptions. Blank entries (for
            example from a trailing comma) are ignored.
        fetch_injuries: Optional callable taking the location and returning
            the query result. Defaults to `get_injury_data`.
        embed: Optional callable mapping a string to a numpy embedding.
            Defaults to `get_symptom_embedding`.

    Returns:
        A 4-tuple: the injury name, the list of its treatment names, the list
        of device names required by those treatments (de-duplicated, in
        first-seen order), and the injury's symptom objects.

    Raises:
        ValueError: If no non-blank symptom is given, if nothing is stored for
            `location`, or if none of its injuries has a recorded symptom.
    """
    # Resolve the collaborators at call time so the defaults always refer to
    # the notebook's current definitions.
    if fetch_injuries is None:
        fetch_injuries = get_injury_data
    if embed is None:
        embed = get_symptom_embedding

    user_symptoms = [
        part.strip() for part in (symptoms or '').split(',') if part.strip()
    ]
    if not user_symptoms:
        raise ValueError("No symptoms given; enter at least one symptom.")

    locations = fetch_injuries(location)
    if not locations:
        raise ValueError(f"No injury data found for location {location!r}.")

    injuries = list(locations[0].injuries)

    # Embed every stored symptom exactly once. These embeddings do not depend
    # on the user's input, so they are computed outside the per-symptom loop.
    reference = [
        (index, embed(db_symptom.description))
        for index, injury in enumerate(injuries)
        for db_symptom in injury.symptoms
    ]
    if not reference:
        raise ValueError(
            f"No symptoms are recorded for injuries at location {location!r}."
        )

    votes = [0] * len(injuries)
    for user_symptom in user_symptoms:
        user_embedding = embed(user_symptom)
        # min() keeps the first entry on equal distance, matching a strict
        # "<" comparison while scanning in order.
        nearest_index, _ = min(
            reference,
            key=lambda entry: np.linalg.norm(user_embedding - entry[1]),
        )
        votes[nearest_index] += 1

    # max() over indices returns the lowest index among tied vote counts.
    best_injury = injuries[max(range(len(injuries)), key=votes.__getitem__)]
    best_treatments = best_injury.treatments

    treatment_names = [treatment.name for treatment in best_treatments]
    # dict.fromkeys de-duplicates while keeping a stable, first-seen order.
    medical_devices = list(dict.fromkeys(
        device.name
        for treatment in best_treatments
        for device in treatment.requires_devices
    ))

    return best_injury.name, treatment_names, medical_devices, best_injury.symptoms

def process_injury(location, symptoms):
    """Run the injury prediction and format the result as display strings.

    Args:
        location: Injury location name passed through to `find_best_match`.
        symptoms: Comma-separated symptom text passed through to
            `find_best_match`.

    Returns:
        A 4-tuple: the predicted injury name as returned by
        `find_best_match`, followed by three ", "-joined strings holding the
        treatment names, the medical device names and the descriptions of the
        injury's stored symptoms.
    """
    injury_name, treatment_names, device_names, symptom_records = find_best_match(
        location, symptoms
    )
    symptom_text = ', '.join(record.description for record in symptom_records)
    return injury_name, ', '.join(treatment_names), ', '.join(device_names), symptom_text


In [6]:
# Input widgets: where the injury is, and what the user is feeling.
dropdown = gr.Dropdown(
    choices=["Shoulder", "Knee", "Feet"],
    label="Injury Location",
)
textbox = gr.Textbox(
    lines=2,
    placeholder="Enter symptoms separated by commas",
    label="Symptoms",
)

# One text field per value returned by process_injury, in the same order.
result_boxes = [
    gr.Textbox(label=caption)
    for caption in ("Predicted Injury", "Treatment", "Medical Devices", "Typical Symptoms")
]

interface = gr.Interface(
    fn=process_injury,
    inputs=[dropdown, textbox],
    outputs=result_boxes,
    title="Injury Predictor",
)

# share=True also publishes the app on a temporary public gradio.live URL.
# NOTE(review): debug=True appears to keep this cell running until the kernel
# is interrupted — confirm against the Gradio docs.
interface.launch(share=True, debug=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://8c99069a254e8d9477.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://8c99069a254e8d9477.gradio.live


