In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import classification_report, accuracy_score

In [2]:
df = pd.read_csv('symptoms_data/refined_synthetic_dental_dataset.csv')
df

Unnamed: 0,Description,label
0,Persistent tooth pain likely from caries,caries
1,Potential tooth loss due to swollen gums,gingivitis
2,Persistent mouth ulcers causing irritation,mouth_ulcer
3,Teeth becoming yellow due to natural aging,tooth_discoloration
4,Looking for solutions for yellowed teeth,tooth_discoloration
...,...,...
995,Persistent mouth ulcers causing irritation,mouth_ulcer
996,Teeth becoming yellow due to natural aging,tooth_discoloration
997,Persistent mouth ulcers causing irritation,mouth_ulcer
998,Gum inflammation noticed with signs of gingivitis,gingivitis


In [3]:
# Split the data
X = df["Description"]
y = df["label"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# Convert text to features
vectorizer = TfidfVectorizer()  # You can switch to CountVectorizer() if needed
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)

In [4]:
from sklearn.naive_bayes import MultinomialNB

# Train Naive Bayes model
nb_model = MultinomialNB()
nb_model.fit(X_train_vec, y_train)

# Evaluate
y_pred_nb = nb_model.predict(X_test_vec)
print("Naive Bayes Results:")
print(classification_report(y_test, y_pred_nb))
print("Accuracy:", accuracy_score(y_test, y_pred_nb))

Naive Bayes Results:
                     precision    recall  f1-score   support

             caries       1.00      1.00      1.00        60
         gingivitis       1.00      1.00      1.00        60
         hypodontia       1.00      1.00      1.00        60
        mouth_ulcer       1.00      1.00      1.00        60
tooth_discoloration       1.00      1.00      1.00        60

           accuracy                           1.00       300
          macro avg       1.00      1.00      1.00       300
       weighted avg       1.00      1.00      1.00       300

Accuracy: 1.0


In [5]:
from joblib import dump, load

In [6]:
saved_model_path = 'models/nb_model.joblib'
vectorizer_model_path = 'models/vectorizer.joblib'
try:
    dump(value=nb_model, filename=saved_model_path)
    print(f"Model has been saved to path: {saved_model_path}")

    dump(vectorizer, vectorizer_model_path)
    print(f"Vectorizer has also been saved to path: {vectorizer_model_path}")
except:
    print("Some exception occured!")

Model has been saved to path: models/nb_model.joblib
Vectorizer has also been saved to path: models/vectorizer.joblib


### Creating a function to classify text

In [7]:
from joblib import load
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# Load saved model and vectorizer
saved_nb_model = load(filename='models/nb_model.joblib')
saved_vectorizer = load(filename='models/vectorizer.joblib')

def nb_predict(text: str, model: MultinomialNB = saved_nb_model, vectorizer: TfidfVectorizer = saved_vectorizer):
    """
    Takes in an entire text, separates it into segments of text, 
    and identifies the symptoms of oral disease.

    Parameters:
    - text (str): The input text to analyze.
    - model (MultinomialNB): The trained Naive Bayes model.
    - vectorizer (TfidfVectorizer): The fitted TfidfVectorizer for text transformation.

    Returns:
    - set: A set of predicted disease categories.
    """
    # Split the input text into sentences or segments
    text_list = text.split(". ")
    output_set = set()

    for txt in text_list:
        # Transform the text using the vectorizer
        transformed_text = vectorizer.transform([txt])  # Input must be a list
        # Predict the class
        pred_class = model.predict(transformed_text)[0]  # Get the first (and only) prediction
        # Add the prediction to the set
        output_set.add(pred_class)

    return output_set

In [8]:
nb_predict("I have cavity. My gums are getting red. My teeth are yellow in color.")

{'caries', 'gingivitis', 'tooth_discoloration'}