In [11]:
import tensorflow as tf
import numpy as np
import json

In [12]:
def load_model_and_mappings(model_path, mappings):
    """Load a saved Keras model together with its label/id lookup tables.

    Args:
        model_path: Path handed to ``tf.keras.models.load_model``.
        mappings: Path to a JSON file containing the keys ``'id_to_label'``
            and ``'label_to_id'``.

    Returns:
        A 3-tuple ``(model, id_to_label, label_to_id)``. ``id_to_label`` has
        its keys converted to ``int``. If the model or the mappings cannot be
        loaded, a message is printed and ``(None, None, None)`` is returned,
        so callers can always unpack exactly three values.
    """
    # Every failure path returns the same arity as the success path.
    failure = (None, None, None)

    try:
        model = tf.keras.models.load_model(model_path)
    except Exception as e:
        print(f"Error loading model or tokenizer, make sure the model's path does exist: {e}")
        return failure

    try:
        with open(mappings, 'r', encoding='utf-8') as f:
            specialty_and_id_map = json.load(f)
    except OSError as e:
        print("Data not found, make sure to run the specialty_data_preprocessing.ipynb file in its entirety to retrieve the data")
        print(f"Details: {e}")
        return failure
    except ValueError as e:
        # json.JSONDecodeError is a subclass of ValueError.
        print(f"Label mappings file is not valid JSON: {e}")
        return failure

    try:
        # JSON object keys are always strings, so the ids are converted back to int.
        id_to_label = {int(k): v for k, v in specialty_and_id_map['id_to_label'].items()}
        label_to_id = specialty_and_id_map['label_to_id']
    except (KeyError, TypeError, ValueError, AttributeError) as e:
        print(f"Label mappings file is missing or has malformed 'id_to_label'/'label_to_id' entries: {e!r}")
        return failure

    return model, id_to_label, label_to_id

def predict_specialty(text_input, model_path, mappings):
    """Predict the label for one piece of text using a saved model.

    Args:
        text_input: The text to classify. It is wrapped in a one-element
            batch and passed to ``model.predict`` via ``tf.constant``
            (assumes the saved model accepts raw strings — TODO confirm).
        model_path: Path forwarded to ``load_model_and_mappings``.
        mappings: Path to the label-mappings JSON, forwarded likewise.

    Returns:
        ``(predicted_specialty, confidence)``, where ``confidence`` is the
        entry of the model's first output row at the arg-max index. Returns
        ``(None, None)`` when the model or the mappings could not be loaded.
        An id missing from the mapping is reported as ``"Unknown Label"``.
    """
    # Index the result instead of unpacking a fixed number of names, so a
    # failure tuple of any length from the loader cannot raise ValueError here.
    loaded = load_model_and_mappings(model_path, mappings)
    model, id_to_label = loaded[0], loaded[1]
    if model is None or id_to_label is None:
        return None, None

    text_input_tensor = tf.constant([text_input])

    # predict() returns one row per batch element; the batch has a single item.
    probabilities = model.predict(text_input_tensor)[0]

    # Convert the NumPy integer to a plain int so it matches the dict's int keys.
    predicted_class_id = int(np.argmax(probabilities))
    confidence = probabilities[predicted_class_id]

    predicted_specialty = id_to_label.get(predicted_class_id, "Unknown Label")

    return predicted_specialty, confidence


In [13]:
# Both paths below are relative, so they resolve against the notebook's
# current working directory; replace them with absolute paths if the kernel
# is started from a different directory.
MAPPINGS = '../../Data/Specialty-Data/specialty_data_label_mappings.json'
# Change MODEL to the model you want to test
MODEL = "./Saved-Models/RNN/base_rnn.keras"
# MODEL = "./Saved-Models/RNN/glove_rnn.keras"

# Sample text to classify.
TRANSCRIPTION = """
2-D M-MODE: , ,1.  Left atrial enlargement with left atrial diameter of 4.7 cm.,2.  Normal size right and left ventricle.,3.  Normal LV systolic function with left ventricular ejection fraction of 51%.,4.  Normal LV diastolic function.,5.
No pericardial effusion.,6.  Normal morphology of aortic valve, mitral valve, tricuspid valve, and pulmonary valve.,7.  PA systolic pressure is 36 mmHg.,DOPPLER: , ,1.  Mild mitral and tricuspid regurgitation.,2.  Trace aortic and pulmonary regurgitation.
"""

specialty, confidence = predict_specialty(TRANSCRIPTION, MODEL, MAPPINGS)

print("PREDICTION")
print("-------------------------------------")
if specialty is None or confidence is None:
    # predict_specialty returns (None, None) when loading failed; formatting
    # `confidence*100` in that case would raise TypeError.
    print("Prediction unavailable: the model or label mappings could not be loaded.")
else:
    print(f"Specialty:   {specialty}")
    print(f"Confidence:  {confidence*100:.2f}%")
print("-------------------------------------")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
PREDICTION
-------------------------------------
Specialty:    Cardiovascular / Pulmonary
Confidence:  84.60%
-------------------------------------
