In [None]:
#pip install -U transformers datasets accelerate bitsandbytes peft

import pandas as pd
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import torch.nn.functional as F

# Load the labelled symptom texts and the disease -> drug lookup table.
# Both CSVs are read relative to the notebook's working directory.
disease_df = pd.read_csv("disease.csv")
drug_df = pd.read_csv("drug.csv")

# Map the string labels to contiguous integer ids; `le` is kept so that
# predictions can be mapped back to label names later.
le = LabelEncoder()
disease_df["label_id"] = le.fit_transform(disease_df["label"])
num_labels = len(le.classes_)

hf_dataset = Dataset.from_pandas(disease_df[["text", "label_id"]])
# Fixed seed: without it the train/test partition (and therefore every
# evaluation number) changes on each Restart & Run All.
hf_dataset = hf_dataset.train_test_split(test_size=0.2, seed=42)

# Checkpoint used for both the tokenizer and the classification model.
model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The classification head is sized from the number of distinct labels found above.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, device_map="auto")

def tokenize(batch):
    """Tokenize the ``"text"`` entries of a batch to a fixed length of 128.

    Args:
        batch: Mapping that holds the raw strings under the ``"text"`` key.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those strings,
        truncated and padded to exactly 128 tokens.
    """
    texts = batch["text"]
    encoding_options = {
        "max_length": 128,
        "padding": "max_length",
        "truncation": True,
    }
    return tokenizer(texts, **encoding_options)

# Tokenize every split, drop the raw text, and expose the integer label ids
# under the column name "labels" before switching the output format to torch.
tokenized = (
    hf_dataset.map(tokenize, batched=True)
    .remove_columns(["text"])
    .rename_column("label_id", "labels")
)
tokenized.set_format("torch")

# Hyper-parameters for the fine-tuning run.
training_args = TrainingArguments(
    # Where checkpoints are written, and how often to evaluate / save.
    output_dir="/content/bloom-disease-model",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Optimisation settings.
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Batch sizes per device.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Disable external experiment trackers.
    report_to="none",
)

# Trainer subclass whose compute_loss accepts `num_items_in_batch` but never
# hands it on to the parent implementation.
class CustomTrainer(Trainer):
    """``Trainer`` variant that swallows the ``num_items_in_batch`` argument."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=0):
        """Compute the loss through the parent class, minus ``num_items_in_batch``.

        ``num_items_in_batch`` is accepted so the caller can pass it, but it is
        deliberately not forwarded; any entry of that name inside ``inputs`` is
        removed as well.
        NOTE(review): presumably needed because the model's forward rejects
        this extra keyword -- confirm before forwarding it.
        """
        # Drop the key if it is present; a missing key is not an error.
        inputs.pop("num_items_in_batch", None)
        parent = super()
        return parent.compute_loss(model, inputs, return_outputs)

# Use the CustomTrainer instead of the base Trainer.
# The tokenizer is passed as `processing_class`: the older `tokenizer=` keyword
# of Trainer.__init__ is deprecated in the transformers releases that also
# introduced `num_items_in_batch`, and is slated for removal.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    processing_class=tokenizer,
)

# Fine-tune; evaluation and checkpointing happen once per epoch (see training_args).
trainer.train()

def predict_top3(text):
    """Return the three most probable labels for ``text``.

    Args:
        text: Free-text input to classify.

    Returns:
        A list of ``(label, probability)`` tuples ordered from most to least
        probable. The labels come from ``le.classes_`` and the probabilities
        are a softmax over the model logits. The list is shorter than three
        when the classifier has fewer than three classes.
    """
    # Make sure train-only behaviour (e.g. dropout) is off for inference,
    # regardless of the mode the training loop left the model in.
    model.eval()
    # Cap the input at the same 128 tokens used when tokenizing the training
    # data; `truncation=True` on its own falls back to the tokenizer's
    # model_max_length instead of this limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Cast to float32 first so the NumPy conversion also works for
    # reduced-precision logits.
    probs = F.softmax(logits.float(), dim=-1).cpu().numpy()[0]
    # argsort is ascending: keep the last three indices and reverse them.
    ids = probs.argsort()[-3:][::-1]
    return [(le.classes_[i], float(probs[i])) for i in ids]

def get_drugs(disease):
    """Look up the drugs listed for ``disease`` in ``drug_df``.

    Matching on the ``"disease"`` column ignores letter case and surrounding
    whitespace, so a lower-case label can still match a capitalised entry in
    the table. Missing drug values are skipped and repeated drug names are
    collapsed, keeping the order of first appearance.

    Args:
        disease: Disease name to look up.

    Returns:
        List of distinct drug names, or ``["No drug info"]`` when nothing
        matches.
    """
    wanted = str(disease).strip().casefold()
    known = drug_df["disease"].astype(str).str.strip().str.casefold()
    matches = drug_df.loc[known == wanted, "drug"].dropna()
    # dict.fromkeys de-duplicates while preserving insertion order.
    meds = list(dict.fromkeys(matches.tolist()))
    return meds if meds else ["No drug info"]

def diagnose(text):
    """Combine the top predictions for ``text`` with their drug lists.

    Args:
        text: Free-text input passed to ``predict_top3``.

    Returns:
        One dict per prediction, in the order ``predict_top3`` yields them,
        with the keys ``"disease"``, ``"probability"`` and ``"medications"``.
    """
    report = []
    for name, score in predict_top3(text):
        entry = {
            "disease": name,
            "probability": score,
            "medications": get_drugs(name),
        }
        report.append(entry)
    return report


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


KeyboardInterrupt: 

In [None]:
def chat():
    """Run an interactive prompt that prints a diagnosis for each user message.

    Reads lines with ``input("You: ")`` until the user types ``exit`` or
    ``quit`` (any letter case, surrounding whitespace ignored) or the input
    stream ends / is interrupted. Blank lines are skipped. For every other
    message the results of ``diagnose`` are printed, one disease per line
    followed by its drugs.
    """
    while True:
        try:
            text = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted prompt ends the chat cleanly
            # instead of surfacing a traceback.
            print("\nBot: Bye")
            break
        if not text:
            # Nothing to classify; ask again rather than running the model.
            continue
        if text.lower() in ["exit", "quit"]:
            print("Bot: Bye")
            break
        res = diagnose(text)
        print("\nBot:")
        for r in res:
            print(f"- {r['disease']} ({r['probability']:.2f})")
            print("  Drugs:", ", ".join(r["medications"]))
        print()

# Start the interactive loop; it blocks on input() until the user types "exit" or "quit".
chat()


You: I have burning pain in my stomach

Bot:
- peptic ulcer disease (0.99)
  Drugs: No drug info
- Acne (0.00)
  Drugs: doxycycline, spironolactone, minocycline, clindamycin, tretinoin, isotretinoin, benzoyl peroxide, adapalene, tetracycline, cephalexin, sulfamethoxazole / trimethoprim, doxycycline, spironolactone, minocycline, clindamycin, tretinoin, isotretinoin, benzoyl peroxide, adapalene, tetracycline, cephalexin, sulfamethoxazole / trimethoprim, doxycycline, spironolactone, minocycline, clindamycin, tretinoin, isotretinoin, benzoyl peroxide, adapalene, tetracycline, cephalexin, sulfamethoxazole / trimethoprim, doxycycline, spironolactone, minocycline, clindamycin, tretinoin, isotretinoin, benzoyl peroxide, adapalene, tetracycline, cephalexin, sulfamethoxazole / trimethoprim, doxycycline, spironolactone, minocycline, clindamycin, tretinoin, isotretinoin, benzoyl peroxide, adapalene, tetracycline, cephalexin, sulfamethoxazole / trimethoprim, doxycycline, spironolactone, minocycline