In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
tokenizer=AutoTokenizer.from_pretrained("final-model")
model=AutoModelForSequenceClassification.from_pretrained("final-model")
id2label=model.config.id2label

In [3]:
def predict_disease(text):
    print(f"\nOriginal user input: '{text}'")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(model.device)
    outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).detach().cpu().numpy()[0]
    top3_idx = probs.argsort()[-3:][::-1]
    top3_diseases = [id2label[i] for i in top3_idx]
    top3_probs = [probs[i] for i in top3_idx]
    return list(zip(top3_diseases, top3_probs))

In [4]:
symptom_text = "Difficulty breathing, wheezing, chest tightness, and persistent cough especially at night"
print(predict_disease(symptom_text))


Original user input: 'Difficulty breathing, wheezing, chest tightness, and persistent cough especially at night'
[('asthma', np.float32(0.3598367)), ('acute respiratory distress syndrome (ards)', np.float32(0.20570575)), ('acute bronchospasm', np.float32(0.16959901))]


In [5]:
import pandas as pd
df=pd.read_csv("final_symptoms_to_disease.csv")

In [6]:
df["diseases"].value_counts()

diseases
cystitis                          1219
vulvodynia                        1218
nose disorder                     1218
complex regional pain syndrome    1217
spondylosis                       1216
                                  ... 
fracture of the arm                437
oppositional disorder              431
hypovolemia                        424
abdominal hernia                   407
inguinal hernia                    402
Name: count, Length: 254, dtype: int64

In [7]:
df["diseases"].unique()

array(['panic disorder', 'eye alignment disorder', 'vaginitis',
       'glaucoma', 'eating disorder', 'transient ischemic attack',
       'pyelonephritis', 'chronic pain disorder',
       'problem during pregnancy', 'choledocholithiasis',
       'diabetic retinopathy', 'fibromyalgia', 'acute pancreatitis',
       'thrombophlebitis', 'asthma', 'teething syndrome',
       'infectious gastroenteritis', 'acute sinusitis',
       'postpartum depression', 'spondylitis', 'uterine fibroids',
       'chalazion', 'vaginal yeast infection', 'ingrown toe nail',
       'corneal disorder', 'viral warts', 'stroke',
       'pelvic organ prolapse', 'fracture of the arm', 'hyperkalemia',
       'cornea infection', 'chronic sinusitis', 'conductive hearing loss',
       'abdominal hernia', 'marijuana abuse', 'indigestion', 'bursitis',
       'pulmonary congestion', 'actinic keratosis', 'acute otitis media',
       'chronic obstructive pulmonary disease (copd)', 'spondylosis',
       'herpangina', 'injury 

In [8]:
df[df["diseases"] == "heart attack"].nunique()

diseases          1
symptom_text    810
dtype: int64

In [9]:
from sklearn.utils import shuffle
df = shuffle(df).reset_index(drop=True)
df.head()

Unnamed: 0,diseases,symptom_text
0,diaper rash,"irritable infant, vomiting, fever, temper prob..."
1,chronic obstructive pulmonary disease (copd),"shortness of breath, sharp chest pain, chest t..."
2,brachial neuritis,"headache, neck pain, low back pain, elbow pain..."
3,fracture of the leg,"hip pain, knee pain, knee swelling, leg stiffn..."
4,marijuana abuse,"anxiety and nervousness, depressive or psychot..."


In [10]:
def predict_disease(text):
    print(f"\nOriginal user input: '{text}'")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(model.device)
    outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).detach().cpu().numpy()[0]
    top3_idx = probs.argsort()[-3:][::-1]
    top3_diseases = [id2label[i] for i in top3_idx]
    top3_probs = [probs[i] for i in top3_idx]
    return list(zip(top3_diseases, top3_probs))

In [12]:
def run_test_cases():
	"""
	Run predictions for all test cases using predict_disease (defined in notebook),
	and calculate Top-1 and Top-3 accuracy.
	"""
	top1_correct = 0
	top3_correct = 0
	total = 100

	for i in range(1, total+1):
		symptoms=df["symptom_text"][i-1]  # Get symptoms from dataframe
		expected=df["diseases"][i-1]  # Get expected disease from dataframe
		try:
			preds = predict_disease(symptoms)  # returns list of (disease, prob)
			if not preds:
				print(f"❌ {i}. {expected}: no predictions")
				continue

			top1 = preds[0][0].lower()
			top3 = [p[0].lower() for p in preds]
			expected_lower = expected.lower()

			is_top1 = expected_lower in top1 or top1 in expected_lower
			is_top3 = any(expected_lower in d or d in expected_lower for d in top3)

			if is_top1:
				top1_correct += 1
				top3_correct += 1
				print(f"✅ {i}. {expected}: TOP-1 -> {preds[0][0]}")
			elif is_top3:
				top3_correct += 1
				print(f"⚠️  {i}. {expected}: in TOP-3 -> {top3}")
			else:
				print(f"❌ {i}. {expected}: TOP-1 -> {preds[0][0]} | TOP-3 -> {top3}")

		except Exception as e:
			print(f"❌ {i}. {expected}: Error -> {e}")

	top1_acc = (top1_correct / total) * 100 if total else 0.0
	top3_acc = (top3_correct / total) * 100 if total else 0.0

	print("\n" + "="*40)
	print(f"RESULTS: Total={total}")
	print(f"Top-1 Accuracy: {top1_acc:.2f}% ({top1_correct}/{total})")
	print(f"Top-3 Accuracy: {top3_acc:.2f}% ({top3_correct}/{total})")
	print("="*40)

	return top1_acc, top3_acc

# Run evaluation on the provided test_cases
run_test_cases()


Original user input: 'irritable infant, vomiting, fever, temper problems, skin rash, diaper rash'
✅ 1. diaper rash: TOP-1 -> diaper rash

Original user input: 'shortness of breath, sharp chest pain, chest tightness, sore throat, nasal congestion, fever, congestion in chest'
✅ 2. chronic obstructive pulmonary disease (copd): TOP-1 -> chronic obstructive pulmonary disease (copd)

Original user input: 'headache, neck pain, low back pain, elbow pain, paresthesia, shoulder pain'
✅ 3. brachial neuritis: TOP-1 -> brachial neuritis

Original user input: 'hip pain, knee pain, knee swelling, leg stiffness or tightness'
✅ 4. fracture of the leg: TOP-1 -> fracture of the leg

Original user input: 'anxiety and nervousness, depressive or psychotic symptoms, difficulty speaking, abusing alcohol, hostile behavior, delusions or hallucinations, fears and phobias'
✅ 5. marijuana abuse: TOP-1 -> marijuana abuse

Original user input: 'diminished vision, double vision, symptoms of eye, pain in eye, abnorma

(88.0, 98.0)