# In [8]:
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin

# Example data: a toy table of 20 patients. Each row holds an integer 'age',
# an integer 'length' (presumably body height in cm -- not stated here) and a
# free-text 'diagnosis' note that the TF-IDF step later vectorizes.
data = pd.DataFrame(dict(
    age=[25, 40, 60, 20, 35, 45, 55, 30, 50, 38,
         44, 28, 26, 52, 48, 33, 46, 29, 31, 36],
    length=[165, 170, 175, 160, 180, 172, 169, 173, 168, 171,
            174, 163, 167, 162, 177, 166, 178, 164, 161, 179],
    diagnosis=[
        "Patient presents with fever, cough, and fatigue, suggestive of influenza.",
        "Abdominal pain, bloating, and diarrhea indicate possible gastroenteritis.",
        "Joint pain, swelling, and stiffness are common in rheumatoid arthritis.",
        "Skin rash, itching, and redness may be symptoms of eczema.",
        "Headache, dizziness, and blurred vision could indicate migraine.",
        "Chest pain, shortness of breath, and palpitations may signal heart disease.",
        "Excessive thirst, frequent urination, and fatigue are signs of diabetes.",
        "Mood swings, sadness, and loss of interest may indicate depression.",
        "Muscle weakness, tremors, and fatigue can occur in multiple sclerosis.",
        "Jaundice, abdominal pain, and dark urine suggest hepatitis.",
        "Frequent infections, fatigue, and swollen lymph nodes indicate HIV.",
        "Vision changes, eye pain, and sensitivity to light may suggest glaucoma.",
        "Chest tightness, wheezing, and coughing are symptoms of asthma.",
        "Frequent falls, memory loss, and confusion are signs of Alzheimer's disease.",
        "Back pain, numbness, and tingling could be due to sciatica.",
        "Swollen glands, sore throat, and fever may indicate strep throat.",
        "Excessive sweating, weight loss, and palpitations may point to hyperthyroidism.",
        "Abdominal discomfort, bloating, and changes in bowel movements suggest IBS.",
        "Joint stiffness, swelling, and limited range of motion are common in arthritis.",
        "Fever, chills, and body aches are typical of the flu.",
    ],
))

# Custom transformer to select text column for TF-IDF
class TextSelector(BaseEstimator, TransformerMixin):
    def __init__(self, key):
        self.key = key

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.key]

# Numeric columns are standardized (zero mean, unit variance) before they
# enter the distance computation alongside the text features.
numerical_features = ['age', 'length']
numerical_transformer = Pipeline([('scaler', StandardScaler())])

# Preprocessing pipeline for the free-text 'diagnosis' column.
#
# There is deliberately no column-selector step here. The ColumnTransformer
# that uses this pipeline is given the scalar column name 'diagnosis', so it
# already hands over a 1-D Series of strings, which is what TfidfVectorizer
# expects. Selecting the column a second time (X['diagnosis'] on that Series)
# is what raised KeyError: 'diagnosis'.
#
# binary=True: term counts are clipped to 0/1 before IDF weighting.
text_transformer = Pipeline(steps=[
    ('tfidf', TfidfVectorizer(ngram_range=(1, 1), binary=True)),
])

# Full feature builder: scaled numeric columns placed side by side with the
# TF-IDF encoding of the diagnosis text. Columns not listed here are dropped
# (ColumnTransformer's default remainder='drop').
# NOTE(review): the two standardized numeric columns are not weighted against
# the L2-normalized TF-IDF block, so they may dominate Euclidean distances;
# consider `transformer_weights` if text similarity should matter more.
preprocessor = ColumnTransformer([
    ('num', numerical_transformer, numerical_features),
    # A scalar column name makes ColumnTransformer pass a 1-D Series
    # (not a one-column DataFrame) to the text pipeline.
    ('text', text_transformer, 'diagnosis'),
])

# Function to train the Nearest Neighbors model
def train_nn_model(data):
    X = data[['age', 'length', 'diagnosis']]
    X_transformed = preprocessor.fit_transform(X)
    nn_model = NearestNeighbors(n_neighbors=5, algorithm='auto').fit(X_transformed)
    return nn_model

# Function to recommend similar diagnoses
def recommend_similar_diagnoses(model, data, age, length, diagnosis):
    """Return the diagnosis texts in ``data`` whose rows are nearest to a new case.

    The query is transformed with the module-level ``preprocessor``, which
    must already be fitted on ``data`` (``train_nn_model`` does this).
    ``model`` must be the NearestNeighbors fitted on that same transformed
    data, so that its row positions line up with ``data``.

    Parameters
    ----------
    model : NearestNeighbors
        Fitted neighbour index returned by ``train_nn_model``.
    data : pd.DataFrame
        The training table the model was fitted on (same row order).
    age, length :
        Numeric features of the query case.
    diagnosis : str
        Free-text description of the query case.

    Returns
    -------
    list of str
        Diagnoses of the neighbouring rows, in the order returned by
        ``model.kneighbors``.
    """
    input_data = pd.DataFrame({'age': [age], 'length': [length], 'diagnosis': [diagnosis]})
    input_transformed = preprocessor.transform(input_data)
    _, indices = model.kneighbors(input_transformed)
    # kneighbors returns row *positions*, so use .iloc. A plain [] would do a
    # label lookup and return wrong rows (or raise KeyError) whenever `data`
    # has a non-default index, e.g. after filtering or shuffling.
    return data['diagnosis'].iloc[indices.flatten()].tolist()


nn_model = train_nn_model(data)

# `new_cases` is not defined in this cell; it may come from an earlier
# notebook cell. Fall back to a few illustrative queries instead of failing
# with NameError when this code is run on its own.
if 'new_cases' not in globals():
    new_cases = [
        {'age': 34, 'length': 170, 'diagnosis': "Persistent cough, fever, and tiredness."},
        {'age': 58, 'length': 176, 'diagnosis': "Pain in the chest and shortness of breath during exercise."},
        {'age': 27, 'length': 164, 'diagnosis': "Itchy red rash on the arms."},
    ]

# Get recommendations for new diagnosis descriptions
for case in new_cases:
    recommendations = recommend_similar_diagnoses(
        nn_model, data, case['age'], case['length'], case['diagnosis']
    )
    print(f"Diagnosis: '{case['diagnosis']}' => Similar Diagnoses:")
    for rec in recommendations:
        print(f" - {rec}")
    print()


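# Output of the original run: KeyError: 'diagnosis'. ColumnTransformer passed
# the already-selected 'diagnosis' Series to TextSelector, which then tried to
# index that Series by the column name.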