<a href="https://colab.research.google.com/github/RanjanTarun27/text-to-disease/blob/main/text_to_disease.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q transformers torch gradio matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import gradio as gr
import numpy as np

class ResumeReadyClassifier(nn.Module):
    """BERT encoder with an attention-pooling head and multi-sample dropout.

    Token embeddings from ``bert-base-cased`` are scored by a small MLP,
    softmax-normalised over the sequence axis (padding excluded), and summed
    into one context vector per example, which a linear layer classifies.
    The classifier is applied behind several independent dropout layers and
    the resulting logits are averaged ("multi-sample dropout"); in eval mode
    ``nn.Dropout`` is an identity, so all samples coincide.

    Args:
        n_classes: Number of output labels.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        # Read the encoder width from the loaded checkpoint instead of
        # hard-coding 768, so the heads stay consistent with the encoder.
        hidden_size = self.bert.config.hidden_size

        # Additive attention scorer: one unnormalised relevance score per token.
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
        )

        # Five dropout layers share one classifier (multi-sample dropout).
        self.dropouts = nn.ModuleList([nn.Dropout(0.3) for _ in range(5)])
        self.classifier = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, attn_weights)`` for a batch of token ids.

        Args:
            input_ids: Token-id tensor, assumed shape (batch, seq_len).
            attention_mask: Same shape as ``input_ids``; non-zero marks real
                tokens, zero marks padding. ``None`` disables pooling masks.

        Returns:
            logits: (batch, n_classes) unnormalised class scores.
            attn_weights: (batch, seq_len, 1) pooling weights; each example's
                weights sum to 1 and padded positions receive weight 0.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)

        scores = self.attention(sequence_output)  # (batch, seq_len, 1)
        if attention_mask is not None:
            # Without this, [PAD] positions would receive pooling weight and
            # leak into the context vector whenever inputs are padded.
            pad = attention_mask.unsqueeze(-1) == 0
            scores = scores.masked_fill(pad, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=1)
        context_vector = torch.sum(attn_weights * sequence_output, dim=1)

        # Average the classifier output over every dropout sample.
        logits = torch.mean(torch.stack([
            self.classifier(d(context_vector)) for d in self.dropouts
        ], dim=0), dim=0)

        return logits, attn_weights


# Prefer the GPU when CUDA is available; otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output labels; position i names the model's i-th logit (predict_symptoms
# pairs these with the softmax probabilities in order).
CLASSES = ["Covid-19", "Influenza", "Common Cold", "Pneumonia", "Bronchitis", "Allergy"]

# Same 'bert-base-cased' vocabulary the classifier's encoder is loaded from.
TOKENIZER = BertTokenizer.from_pretrained('bert-base-cased')

# NOTE(review): no training loop or checkpoint load is visible here, so the
# attention and classifier heads keep their random initialisation — confirm
# whether fine-tuned weights are loaded elsewhere before trusting outputs.
model = ResumeReadyClassifier(n_classes=len(CLASSES))
model.to(DEVICE)
model.eval()  # inference only: disables the dropout layers


def predict_symptoms(text):
    """Classify a free-text symptom description and report token attention.

    Args:
        text: Symptom description from the Gradio textbox. ``None`` is
            treated as an empty string.

    Returns:
        A ``(confidences, explanation)`` pair: a ``{label: probability}``
        dict over ``CLASSES`` (rendered by ``gr.Label``) and a Markdown
        string listing every non-special token with its attention weight.
    """
    if text is None:
        text = ""
    # Inputs longer than 128 tokens are truncated.
    inputs = TOKENIZER(
        text, return_tensors="pt", padding=True, truncation=True, max_length=128
    ).to(DEVICE)

    with torch.no_grad():
        logits, weights = model(inputs['input_ids'], inputs['attention_mask'])
        # One input string -> one row; flatten to per-class probabilities and
        # copy to the host once instead of syncing per element.
        probabilities = F.softmax(logits, dim=1).flatten().cpu().tolist()

    confidences = {label: prob for label, prob in zip(CLASSES, probabilities)}

    tokens = TOKENIZER.convert_ids_to_tokens(inputs['input_ids'][0])
    attn_scores = weights[0].flatten().cpu().tolist()

    # Special tokens carry no symptom content, so hide them from the readout.
    special = {'[CLS]', '[SEP]', '[PAD]'}
    word_importance = "".join(
        f"{token} ({score:.2f})  "
        for token, score in zip(tokens, attn_scores)
        if token not in special
    )

    return confidences, f"Key Symptoms Detected: {word_importance}"


# Markdown shown under the app title. The stethoscope emoji is written as a
# Unicode escape so the text survives editors/exports that are not UTF-8 safe.
desc = """
## \U0001FA7A Advanced Medical Symptom Classifier
**Features:** BERT-Base-Cased, Attention Pooling, Multi-Sample Dropout.
*Type symptoms below to see the model's diagnosis and its confidence levels.*
*For demonstration only; not a substitute for professional medical advice.*
"""

# Gradio front end: a free-text symptom box in; the top-3 labels and a
# per-token attention readout out.
_symptom_box = gr.Textbox(lines=3, placeholder="Describe symptoms (e.g., High fever, dry cough, and fatigue)...")
_diagnosis_label = gr.Label(num_top_classes=3, label="Top Diagnoses")
_explain_panel = gr.Markdown(label="Explainability")
_example_prompts = [
    ["Sudden loss of taste and smell with a mild fever."],
    ["Wheezing and shortness of breath after exercise."],
]

interface = gr.Interface(
    fn=predict_symptoms,
    inputs=_symptom_box,
    outputs=[_diagnosis_label, _explain_panel],
    examples=_example_prompts,
    title="Clinical Decision Support LLM",
    description=desc,
    theme="soft",
)


# Per the Gradio docs, debug=True blocks the cell and prints errors in its
# output, and share=True requests a public share link.
# NOTE(review): the share link is reachable by anyone holding the URL.
interface.launch(debug=True, share=True)