In [3]:
from transformers import BertForSequenceClassification, BertTokenizer
import faiss
import pickle

# All persisted artifacts live under one directory; change it here instead of
# editing every load call below.
MODEL_DIR = "saved_model"
CLASSIFIER_DIR = f"{MODEL_DIR}/bert_disease_classifier"

# Load model & tokenizer (both are read from the same saved directory)
model = BertForSequenceClassification.from_pretrained(CLASSIFIER_DIR)
tokenizer = BertTokenizer.from_pretrained(CLASSIFIER_DIR)

# Load FAISS
index = faiss.read_index(f"{MODEL_DIR}/faiss_index.index")

# Load records and label encoder
# SECURITY: pickle.load can execute arbitrary code while deserialising. Only
# load these files if they were produced by a trusted run of this project.
with open(f"{MODEL_DIR}/full_records.pkl", "rb") as f:
    full_records = pickle.load(f)

with open(f"{MODEL_DIR}/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)


In [5]:
import torch

In [17]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Put the classifier on that device. Being the last expression of the cell,
# this call's return value (the model) is what the notebook displays.
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [18]:
def predict_disease(symptom_input):
    """Classify a symptom description with the fine-tuned BERT model.

    Relies on the notebook-level ``model``, ``tokenizer``, ``device`` and
    ``label_encoder`` objects.

    Args:
        symptom_input: Text describing the symptoms. The prediction is
            reduced with ``.item()``, so exactly one sequence must be
            passed per call.

    Returns:
        The first element of ``label_encoder.inverse_transform`` applied to
        the highest-scoring class id.
    """
    # Switch off dropout and other train-time behaviour before scoring.
    model.eval()

    # Truncate or pad every input to exactly 64 tokens.
    tokens = tokenizer(
        symptom_input,
        max_length=64,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    ids = tokens['input_ids'].to(device)
    mask = tokens['attention_mask'].to(device)

    # Inference only, so no gradients need to be tracked.
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
    class_id = logits.argmax(dim=1).item()

    return label_encoder.inverse_transform([class_id])[0]


In [19]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model; retrieve_full_info encodes each query with it
# before searching the FAISS index.
# NOTE(review): presumably this is the same model that produced the vectors
# stored in the saved index — confirm, since a different model would make the
# nearest-neighbour results meaningless.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

In [20]:
def retrieve_full_info(symptom_input, k=3):
    """Return the stored records closest to a symptom description.

    Encodes the query with the notebook-level ``embedder``, searches the
    notebook-level FAISS ``index`` and maps the returned ids onto
    ``full_records``.

    Args:
        symptom_input: Text describing the symptoms.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A list of at most ``k`` records, in the order the index returned
        them. The list is shorter than ``k`` (possibly empty) when the
        index holds fewer than ``k`` vectors.
    """
    # FAISS expects float32 input.
    query_vector = embedder.encode([symptom_input]).astype('float32')
    _distances, neighbor_ids = index.search(query_vector, k)

    # FAISS fills unused result slots with -1. Indexing full_records with -1
    # would silently return its last record as a bogus match, so those slots
    # are dropped here.
    return [full_records[i] for i in neighbor_ids[0] if i >= 0]


In [21]:
def full_pipeline(symptom_list, k=3):
    """Predict a disease and look up a matching stored record.

    Args:
        symptom_list: Iterable of symptom strings, joined with ", " to form
            the query text. A single string is also accepted and used as
            the query text unchanged.
        k: Number of nearest records to retrieve (default 3).

    Returns:
        The first retrieved record whose ``"Disease"`` equals the predicted
        disease. If none matches, a dict
        ``{"Disease": <prediction>, "entry": <nearest record or None>}``.
    """
    # A bare string is itself iterable: join() would split it into single
    # characters ("flu" -> "f, l, u"), so it is passed through as-is.
    if isinstance(symptom_list, str):
        input_text = symptom_list
    else:
        input_text = ", ".join(symptom_list)

    predicted_disease = predict_disease(input_text)
    retrieved_entries = retrieve_full_info(input_text, k=k)

    # Prefer a retrieved record that agrees with the classifier.
    for entry in retrieved_entries:
        if entry["Disease"] == predicted_disease:
            return entry

    # Fallback: no retrieved record agrees with the classifier. Return the
    # top-ranked record explicitly rather than relying on the loop variable,
    # which is undefined when nothing was retrieved and otherwise holds the
    # last-ranked record.
    nearest = retrieved_entries[0] if retrieved_entries else None
    return {"Disease": predicted_disease, "entry": nearest}


In [22]:
import json

In [23]:
# Demo run: the whole symptom description is supplied as one list element.
symptoms = [
    "Burning stomach pain, bloating, nausea",
]
result = full_pipeline(symptoms)
# Pretty-print the returned record as indented JSON.
print(json.dumps(result, indent=2))

{
  "Disease": "Gastritis",
  "Symptoms": [
    "Stomach pain",
    "bloating",
    "nausea"
  ],
  "Medicines": [
    "Omeprazole",
    "Ranitidine"
  ],
  "Brand Names": [
    "Omez",
    "Zinetac"
  ],
  "Dosages": [
    "20mg",
    "150mg"
  ],
  "Prices (INR)": [
    "30",
    "20"
  ]
}
