✅ Step 1: Setup & Data Preparation

In [None]:
# Basic Libraries
import pandas as pd
import numpy as np
import os
import random
import tensorflow as tf

# Mount Google Drive (Colab only) so the CSV is readable and the trained model persists.
from google.colab import drive
drive.mount('/content/drive')

# Seed every RNG used later (Python, NumPy's global RNG -- used when shuffling
# training batches -- and TensorFlow) so re-runs are more repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load dataset
file_path = '/content/drive/MyDrive/Colab Notebooks/Medical_chatbot/medical_qa_doctor_style_refined.csv'
df = pd.read_csv(file_path)

# Clean dataset. Whitespace is stripped *before* de-duplicating so rows that only
# differ in surrounding spaces collapse into one, and rows that are blank after
# stripping are dropped together with missing values.
df = (
    df.assign(question=df['question'].str.strip(), answer=df['answer'].str.strip())
    .dropna()
    .query("question != '' and answer != ''")
    .drop_duplicates()
    .reset_index(drop=True)
)

# Preview
print(f"Total cleaned samples: {len(df)}")

# Rename to seq2seq column names and add the task prefix used again at inference time.
df = df.rename(columns={'question': 'input_text', 'answer': 'target_text'})
df['input_text'] = 'healthcare question: ' + df['input_text']


Mounted at /content/drive
Total cleaned samples: 285


✅ Step 2: Train/Val/Test Split

In [None]:
from sklearn.model_selection import train_test_split

# 80-10-10 split: hold out 20%, then halve the hold-out into validation and test.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Convert to Hugging Face Datasets. preserve_index=False stops the shuffled
# pandas index from being carried along as an extra `__index_level_0__` column.
from datasets import Dataset, DatasetDict
dataset = DatasetDict({
    split_name: Dataset.from_pandas(split_df, preserve_index=False)
    for split_name, split_df in (('train', train_df), ('validation', val_df), ('test', test_df))
})


✅ Step 3: Tokenization

In [None]:
from transformers import AutoTokenizer

# FLAN-T5 (base) is the starting checkpoint; its tokenizer is shared by all splits.
model_checkpoint = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Questions and answers are both truncated/padded to 128 tokens.
max_input_length = max_target_length = 128

def preprocess(example):
    """Tokenize a batch of question/answer pairs for seq2seq fine-tuning.

    Args:
        example: batched mapping (as passed by ``Dataset.map(batched=True)``)
            with list-valued ``'input_text'`` and ``'target_text'`` columns.

    Returns:
        The tokenizer output for the inputs (``input_ids``, ``attention_mask``)
        plus ``labels``: the target token ids, with every padding position
        replaced by -100 so the Hugging Face seq2seq loss ignores it.
    """
    inputs = tokenizer(example['input_text'], max_length=max_input_length, padding="max_length", truncation=True)
    targets = tokenizer(example['target_text'], max_length=max_target_length, padding="max_length", truncation=True)
    pad_id = tokenizer.pad_token_id
    # Targets are already padded to full length, so the data collator will not
    # re-pad them with -100; without this mask every padded slot would be trained
    # as a real "predict <pad>" target.
    inputs['labels'] = [
        [token if token != pad_id else -100 for token in label_ids]
        for label_ids in targets['input_ids']
    ]
    return inputs

# Tokenize every split in batches and drop the raw text columns.
tokenized_datasets = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset['train'].column_names,
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Map:   0%|          | 0/228 [00:00<?, ? examples/s]

Map:   0%|          | 0/28 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

✅ Step 4: Model Setup

In [None]:
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, create_optimizer

# Load model
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Data collator: turns lists of tokenized records into TF batches.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="tf")

# Training parameters
batch_size = 8
epochs = 30
learning_rate = 5e-5

# The TF dataset (Step 5) also yields the final partial batch, so one epoch is
# ceil(n / batch_size) steps. Floor division undercounted the total, and the
# learning-rate schedule finished decaying before training did.
steps_per_epoch = (len(tokenized_datasets['train']) + batch_size - 1) // batch_size
num_train_steps = steps_per_epoch * epochs
optimizer, schedule = create_optimizer(init_lr=learning_rate, num_warmup_steps=0, num_train_steps=num_train_steps)


config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


✅ Step 5: TF Dataset Creation

In [None]:
def create_tf_dataset_from_hf(dataset, data_collator, batch_size, shuffle=False):
    """Expose a tokenized Hugging Face dataset as a ``tf.data.Dataset`` of collated batches.

    Args:
        dataset: iterable of records, each with ``input_ids``, ``attention_mask``
            and ``labels`` entries (any other fields are ignored).
        data_collator: callable mapping a list of such records to one batch whose
            ``input_ids`` / ``attention_mask`` / ``labels`` are array-like.
        batch_size: records per batch; the final batch may be smaller.
        shuffle: when True, record order is redrawn from NumPy's global RNG each
            time the underlying generator is (re)started.

    Returns:
        ``tf.data.Dataset`` of ``(features, labels)`` pairs, where ``features`` is
        ``{"input_ids": ..., "attention_mask": ...}``; every tensor is declared as
        int32 with shape ``(None, None)``.
    """
    feature_keys = ("input_ids", "attention_mask")
    kept_keys = feature_keys + ("labels",)
    # Keep only the model-facing columns, materialised once up front.
    records = [{key: row[key] for key in kept_keys} for row in dataset]

    def batch_stream():
        order = list(range(len(records)))
        if shuffle:
            np.random.shuffle(order)
        for start in range(0, len(order), batch_size):
            chunk = [records[pos] for pos in order[start:start + batch_size]]
            collated = data_collator(chunk)
            features = {key: np.array(collated[key]) for key in feature_keys}
            yield (features, np.array(collated["labels"]))

    int_matrix = tf.TensorSpec(shape=(None, None), dtype=tf.int32)
    signature = ({key: int_matrix for key in feature_keys}, int_matrix)
    return tf.data.Dataset.from_generator(batch_stream, output_signature=signature)

# Build TF datasets: training order is reshuffled, validation order stays fixed.
tf_train_dataset = create_tf_dataset_from_hf(
    tokenized_datasets["train"], data_collator, batch_size, shuffle=True
)
tf_val_dataset = create_tf_dataset_from_hf(
    tokenized_datasets["validation"], data_collator, batch_size, shuffle=False
)


✅ Step 6: Model Training & Saving

In [None]:
# Compile with an optimizer only -- no loss is passed here; presumably the HF TF
# model computes its own loss from the `labels` tensor (see Transformers docs).
model.compile(optimizer=optimizer)
model.fit(
    tf_train_dataset,
    validation_data=tf_val_dataset,
    epochs=epochs,
)

# Save weights and tokenizer into one directory so they can be reloaded together.
output_dir = "/content/drive/MyDrive/Colab Notebooks/Medical_chatbot/healthcare-chatbot-model"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")


Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Model saved to /content/drive/MyDrive/Colab Notebooks/Medical_chatbot/healthcare-chatbot-model


✅ Step 7: Inference Function

In [None]:
# Lower-case keywords; is_medical_question() checks the lower-cased question against them.
medical_keywords = [
    "symptom", "diagnose", "treatment", "medicine",
    "disease", "doctor", "covid", "cancer",
    "diabetes", "bipolar", "stroke", "fever",
    "infection", "pain", "mental", "health",
    "hospital", "vaccine", "prescription", "disorder",
    "diagnosed", "asthma", "epilepsy", "hypertension",
    "depression", "anxiety", "hiv", "ibuprofen",
    "lisinopril", "side effects", "paracetamol", "atorvastatin",
    "metformin", "checkup", "healthy lifestyle", "symptoms",
    "water", "dose", "blood pressure", "heart",
    "immune", "medication", "mental health", "therapy",
]

def is_medical_question(question, keywords=None):
    """Return True if ``question`` mentions any medical keyword.

    A keyword must begin at a word boundary but may run on, so "symptom" still
    matches "symptoms" and "covid" matches "COVID-19". This avoids plain-substring
    false positives such as "pain" inside "Spain" or "hiv" inside "archive".

    Args:
        question: free-text user question.
        keywords: optional iterable of lower-case keywords; defaults to the
            module-level ``medical_keywords`` list.

    Returns:
        bool
    """
    if keywords is None:
        keywords = medical_keywords
    # Punctuation -> spaces, collapse runs of whitespace, and lead with one space
    # so that every word (including the first) is preceded by " ".
    cleaned = "".join(ch if ch.isalnum() else " " for ch in question.lower())
    padded = " " + " ".join(cleaned.split()) + " "
    return any(" " + keyword in padded for keyword in keywords)

def generate_answer(question):
    """Answer a healthcare question with the fine-tuned model.

    Non-medical questions (per ``is_medical_question``) get a fixed refusal.

    Args:
        question: raw user question (without the task prefix).

    Returns:
        The decoded model answer, or the refusal string.
    """
    if not is_medical_question(question):
        return "❗ Sorry, I can only answer healthcare-related questions."
    # Same task prefix and 128-token input limit as the training data.
    encoded = tokenizer(
        "healthcare question: " + question,
        return_tensors="tf",
        truncation=True,
        max_length=128,
    )
    output = model.generate(
        **encoded,
        max_length=128,
        num_beams=4,
        # Blocks the verbatim sentence loops seen in the sample outputs.
        no_repeat_ngram_size=3,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Smoke-test the fine-tuned model on a few representative questions.
test_questions = [
    "What are the symptoms of stroke?",
    "Can bipolar disorder be detected early?",
    "How is COVID-19 diagnosed?",
]

for sample_question in test_questions:
    print("\n❓ Question:", sample_question)
    print("💬 Answer:", generate_answer(sample_question))



❓ Question: What are the symptoms of stroke?
💬 Answer: Symptoms of stroke include shortness of breath, fatigue. this information is helpful for understanding the condition better. understanding this response helps in gaining deeper insight into the medical condition and encourages timely medical consultation. If you have any concerns or symptoms, it's important to follow up with a healthcare provider for a personalized evaluation. If you have any concerns or symptoms, it's important to follow up with a healthcare provider for a personalized evaluation.

❓ Question: Can bipolar disorder be detected early?
💬 Answer: Yes, certainly, early detection of bipolar disorder is possible through questionnaires, ecg. this information is helpful for understanding the condition better. healthcare providers use behavioral assessments, interviews, and family history to evaluate and detect mood disorders at an early stage, potentially preventing more severe episodes. If you have any concerns or sympto

✅ Step 8: Gradio Chat UI

In [None]:
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
import random

# Load the fine-tuned model and tokenizer. This must be the same directory the
# training cell (Step 6) passes to save_pretrained(), otherwise the UI serves a
# different checkpoint than the one just trained.
model_path = "/content/drive/MyDrive/Colab Notebooks/Medical_chatbot/healthcare-chatbot-model"
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Medical keywords for filtering. Kept identical to the Step 7 list so the UI
# accepts the same questions -- the shorter copy lacked e.g. "hypertension" and
# "asthma", which sent the built-in example "How is hypertension treated?" to
# the fun-fact fallback.
medical_keywords = [
    "symptom", "diagnose", "treatment", "medicine", "disease", "doctor",
    "covid", "cancer", "diabetes", "bipolar", "stroke", "fever", "infection",
    "pain", "mental", "health", "hospital", "vaccine", "prescription",
    "disorder", "diagnosed", "asthma", "epilepsy", "hypertension",
    "depression", "anxiety", "hiv", "ibuprofen", "lisinopril", "side effects",
    "paracetamol", "atorvastatin", "metformin", "checkup", "healthy lifestyle",
    "symptoms", "water", "dose", "blood pressure", "heart", "immune",
    "medication", "mental health", "therapy"
]

# Greeting keywords (lower-case words/phrases)
greeting_keywords = [
    "hello", "hi", "hey", "good morning", "good afternoon", "good evening",
    "howdy", "greetings", "what's up", "whats up", "how are you", "sup"
]

# Canned replies for greetings (one is picked at random)
greeting_responses = [
    "Hello! 👋 I'm your medical assistant. How can I help you with your health questions today?",
    "Hi there! 🩺 I'm here to help with any medical or health-related questions you might have.",
    "Greetings! 😊 I'm your healthcare chatbot. Feel free to ask me about symptoms, treatments, or general health information.",
    "Hello! 🏥 Nice to meet you! I'm ready to assist with your medical inquiries.",
    "Hi! 👨‍⚕️ I'm your AI medical assistant. What health topic would you like to discuss today?"
]

# Fallback content for messages that are neither medical nor greetings
fun_facts = [
    "Did you know? The human brain has around 86 billion neurons!",
    "Fun fact: Laughing is good for your heart and can reduce stress.",
    "Tip: Drinking water can improve cognitive performance."
]

def _as_padded_words(text):
    """Lower-case ``text``, turn punctuation (except apostrophes) into single spaces, and pad both ends.

    Produces e.g. " what's up doc " so phrases can be matched on word boundaries.
    """
    cleaned = "".join(ch if ch.isalnum() or ch == "'" else " " for ch in text.lower())
    return " " + " ".join(cleaned.split()) + " "


def is_greeting(q, keywords=None):
    """Return True if ``q`` contains a greeting word or phrase as whole words.

    Whole-word matching stops short greetings ("hi", "hey", "sup") from firing
    inside ordinary words such as "which", "this", "they" or "supplement".

    Args:
        q: user message.
        keywords: optional greeting words/phrases; defaults to ``greeting_keywords``.
    """
    if keywords is None:
        keywords = greeting_keywords
    padded = _as_padded_words(q)
    return any(" " + keyword + " " in padded for keyword in keywords)


def is_medical_question(q, keywords=None):
    """Return True if ``q`` mentions a medical keyword starting at a word boundary.

    Keywords may run on ("symptom" matches "symptoms"), but cannot start
    mid-word ("pain" does not match "Spain").

    Args:
        q: user message.
        keywords: optional keywords; defaults to ``medical_keywords``.
    """
    if keywords is None:
        keywords = medical_keywords
    padded = _as_padded_words(q)
    return any(" " + keyword in padded for keyword in keywords)

def generate_answer(question):
    """Route a chat message to a model answer, a greeting, or a fallback tip.

    The medical check runs first, so a message like "Hi, what are the symptoms
    of flu?" is answered instead of only being greeted. Messages matched by
    ``is_greeting`` get a random canned greeting. Anything else gets a random fun
    fact plus a nudge toward health questions.

    Args:
        question: raw user message.

    Returns:
        Reply text (model answers have their first letter capitalized).
    """
    question_lower = question.lower().strip()

    # Handle medical questions
    if is_medical_question(question_lower):
        # Same task prefix and 128-token input limit as training.
        encoded = tokenizer(
            "healthcare question: " + question,
            return_tensors="tf",
            truncation=True,
            max_length=128,
        )
        # Deterministic beam search. Sampling knobs (temperature/top_k/top_p)
        # are omitted because they only take effect with do_sample=True.
        output = model.generate(
            **encoded,
            max_length=128,
            num_beams=4,
            no_repeat_ngram_size=3,  # avoid verbatim sentence loops
            early_stopping=True,
        )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # Capitalize first letter; slicing keeps an empty decode from raising.
        return response[:1].upper() + response[1:]

    # Handle greetings
    if is_greeting(question_lower):
        return random.choice(greeting_responses)

    # Handle non-medical questions
    return random.choice(fun_facts) + "\n\n🩺 Please ask a healthcare-related question or feel free to greet me!"

# Chat handler for the textbox / send button.
def chatbot_respond(message, history):
    """Append ``message`` and the bot's reply to the chat history.

    The reply is computed before returning, so the UI receives the user message
    and the answer together (no interim "Thinking..." state is ever rendered).

    Args:
        message: text typed by the user; blank or None leaves the history as is.
        history: list of (user, bot) pairs; None is treated as empty.

    Returns:
        ``(new_history, "")`` -- the empty string clears the input textbox.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, ""
    return history + [(message, generate_answer(message))], ""

# Handler for the example-question buttons.
def handle_example_question(question, history):
    """Answer a canned example question and return the extended chat history.

    Args:
        question: the example question text.
        history: list of (user, bot) pairs; None is treated as empty.

    Returns:
        A new history list ending with ``(question, answer)``; the input list is
        not mutated.
    """
    return list(history or []) + [(question, generate_answer(question))]

def clear_chat():
    """Reset the chat window: hand Gradio a fresh, empty history."""
    empty_history = list()
    return empty_history

# Custom CSS passed to gr.Blocks(css=medical_css) when the interface is built.
# It styles the page/container background gradient, header, chat container and
# bubbles, input field, buttons, sidebar and example-question buttons, and the
# "about-section" cards used in the sidebar markdown, plus a pulse animation.
# NOTE(review): nothing in this notebook assigns the `thinking` class, so that
# rule and its animation appear unused. Selectors such as `.btn-primary`,
# `.btn-secondary`, `.textbox input` and `.message.user` / `.message.bot` rely
# on Gradio's internal DOM/class names -- confirm in the browser that they
# actually match for the installed Gradio version.
medical_css = """
/* Medical color scheme with professional appearance */
body {
    background: linear-gradient(135deg, #e8f5e8 0%, #f0f8ff 100%) !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
    padding: 20px !important;
}

.gradio-container {
    background: linear-gradient(135deg, #e8f5e8 0%, #f0f8ff 100%) !important;
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: 20px !important;
    border-radius: 20px !important;
    box-shadow: 0 10px 30px rgba(0,0,0,0.1) !important;
}

/* Header styling */
.markdown h1 {
    color: #2c5530 !important;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1) !important;
}

/* Chat container styling */
.chatbot {
    background: #ffffff !important;
    border: 2px solid #4a90a4 !important;
    border-radius: 15px !important;
    box-shadow: 0 4px 15px rgba(74, 144, 164, 0.2) !important;
}

/* Chat messages styling - WhatsApp-like */
.message.user {
    background: #dcf8c6 !important;
    border-radius: 18px 18px 4px 18px !important;
    margin: 5px 0 !important;
    padding: 8px 12px !important;
    max-width: 80% !important;
    margin-left: auto !important;
    border: 1px solid #b7e5a1 !important;
    word-wrap: break-word !important;
    white-space: normal !important;
    overflow-wrap: break-word !important;
}

.message.bot {
    background: #ffffff !important;
    border: 1px solid #e0e0e0 !important;
    border-radius: 18px 18px 18px 4px !important;
    margin: 5px 0 !important;
    padding: 8px 12px !important;
    max-width: 80% !important;
    margin-right: auto !important;
    box-shadow: 0 1px 2px rgba(0,0,0,0.1) !important;
    word-wrap: break-word !important;
    white-space: normal !important;
    overflow-wrap: break-word !important;
}

/* Fix for chat text display */
.chatbot .message, .chatbot .message p {
    white-space: normal !important;
    word-wrap: break-word !important;
    overflow-wrap: break-word !important;
    word-break: normal !important;
    hyphens: none !important;
    writing-mode: horizontal-tb !important;
    text-orientation: mixed !important;
}

/* Ensure proper text flow in chat bubbles */
.chatbot .wrap {
    white-space: normal !important;
    word-wrap: break-word !important;
}

.chatbot .message-wrap {
    white-space: normal !important;
    word-wrap: break-word !important;
    display: block !important;
}

/* Input field styling */
.textbox input {
    background: #ffffff !important;
    border: 2px solid #4a90a4 !important;
    border-radius: 25px !important;
    padding: 12px 18px !important;
    font-size: 16px !important;
    transition: all 0.3s ease !important;
}

.textbox input:focus {
    border-color: #2c5530 !important;
    box-shadow: 0 0 10px rgba(44, 85, 48, 0.3) !important;
}

/* Button styling */
.btn-primary {
    background: linear-gradient(135deg, #4a90a4 0%, #2c5530 100%) !important;
    border: none !important;
    border-radius: 25px !important;
    padding: 12px 24px !important;
    color: white !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 10px rgba(74, 144, 164, 0.3) !important;
}

.btn-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 15px rgba(74, 144, 164, 0.4) !important;
}

.btn-secondary {
    background: linear-gradient(135deg, #dc3545 0%, #c82333 100%) !important;
    border: none !important;
    border-radius: 25px !important;
    padding: 12px 24px !important;
    color: white !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
}

.btn-secondary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 15px rgba(220, 53, 69, 0.4) !important;
}

/* Sidebar buttons */
.sidebar-btn {
    background: linear-gradient(135deg, #ffffff 0%, #f8f9fa 100%) !important;
    border: 2px solid #4a90a4 !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
    margin: 5px 0 !important;
    color: #2c5530 !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}

.sidebar-btn:hover {
    background: linear-gradient(135deg, #4a90a4 0%, #2c5530 100%) !important;
    color: white !important;
    transform: translateX(5px) !important;
}

/* Example question buttons */
.example-btn {
    background: linear-gradient(135deg, #e3f2fd 0%, #ffffff 100%) !important;
    border: 2px solid #4a90a4 !important;
    border-radius: 20px !important;
    padding: 10px 15px !important;
    margin: 5px !important;
    color: #2c5530 !important;
    font-weight: 500 !important;
    transition: all 0.3s ease !important;
    cursor: pointer !important;
}

.example-btn:hover {
    background: linear-gradient(135deg, #4a90a4 0%, #2c5530 100%) !important;
    color: white !important;
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 10px rgba(74, 144, 164, 0.3) !important;
}

/* About section styling */
.about-section {
    background: rgba(255, 255, 255, 0.9) !important;
    border: 2px solid #4a90a4 !important;
    border-radius: 15px !important;
    padding: 15px !important;
    margin: 10px 0 !important;
    box-shadow: 0 2px 8px rgba(74, 144, 164, 0.2) !important;
}

/* Loading animation for "Thinking..." */
@keyframes pulse {
    0% { opacity: 0.6; }
    50% { opacity: 1; }
    100% { opacity: 0.6; }
}

.thinking {
    animation: pulse 1.5s infinite;
    color: #4a90a4 !important;
    font-style: italic !important;
}
"""

# Build the interface: header, chat column + sidebar, example-question buttons,
# footer; event wiring comes after the layout.
with gr.Blocks(css=medical_css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="green")) as demo:

    # Header
    gr.Markdown("""
    <div style="text-align: center; padding: 15px;">
        <h1 style="color: #2c5530; font-size: 2em; margin-bottom: 8px;">🏥 Medical Knowledge Chatbot</h1>
        <p style="color: #4a90a4; font-size: 16px; font-weight: 500;">
            Your AI assistant for understanding <b>health</b> and <b>medicine</b>
        </p>
        <div style="width: 80px; height: 2px; background: linear-gradient(90deg, #4a90a4, #2c5530); margin: 15px auto; border-radius: 2px;"></div>
    </div>
    """)

    with gr.Row():
        # Main chat area
        with gr.Column(scale=3):
            # The handlers build history as a list of (user, bot) pairs, i.e.
            # Gradio's "tuples" format, so it is stated explicitly rather than
            # left to the deprecated default. `bubble_full_width` is omitted
            # because recent Gradio deprecates it (warnings in the cell output).
            # NOTE(review): avatar_images is documented as image paths/URLs;
            # emoji strings may not render as avatars -- confirm in the browser.
            chatbot_ui = gr.Chatbot(
                label="💬 Medical Assistant Chat",
                height=400,
                show_copy_button=True,
                type="tuples",
                avatar_images=["👤", "🩺"]
            )

            msg = gr.Textbox(
                placeholder="Type your medical question here and press Enter...",
                label="💬 Ask Your Question",
                lines=1,
                max_lines=2,
                container=True
            )

            with gr.Row():
                submit_btn = gr.Button("🚀 Send Message", variant="primary", scale=2)
                clear_btn = gr.Button("🧹 Clear Chat", variant="secondary", scale=1)

        # Sidebar
        with gr.Column(scale=1):
            gr.Markdown("""
            <div class="about-section">
                <h3 style="color: #2c5530; margin-bottom: 15px;">📚 Medical Topics</h3>
            </div>
            """)

            topic_btns = []
            topics = [
                ("🩺 Symptoms", "symptoms"),
                ("🦠 Diseases", "diseases"),
                ("💊 Treatments", "treatments"),
                ("💉 Medications", "medications"),
                ("🧠 Mental Health", "mental health")
            ]

            for topic_name, topic_key in topics:
                btn = gr.Button(topic_name, elem_classes="sidebar-btn")
                topic_btns.append(btn)

            gr.Markdown("""
            <div class="about-section">
                <h3 style="color: #2c5530; margin-bottom: 10px;">ℹ️ About</h3>
                <p style="color: #555; font-size: 14px; line-height: 1.5;">
                    This chatbot uses a fine-tuned <b>T5 model</b> to provide medical information.<br><br>
                    <strong>⚠️ Important:</strong> This is for educational purposes only and should not replace professional medical advice.
                </p>
            </div>
            """)

    # Example questions section
    gr.Markdown("""
    <div style="text-align: center; margin: 20px 0 15px 0;">
        <h3 style="color: #2c5530; font-size: 1.3em;">💡 Try These Example Questions</h3>
        <p style="color: #666; margin-bottom: 15px; font-size: 14px;">Click on any question to get started</p>
    </div>
    """)

    example_questions = [
        "What are the symptoms of diabetes?",
        "How is hypertension treated?",
        "What is bipolar disorder?",
        "What medicine is used for asthma?",
        "Is COVID-19 contagious?"
    ]

    with gr.Row():
        example_btns = []
        for question in example_questions:
            btn = gr.Button(question, elem_classes="example-btn", size="sm")
            example_btns.append(btn)

    # Event handlers
    def submit_message(message, history):
        """Forward a typed message to chatbot_respond (returns history, cleared textbox)."""
        return chatbot_respond(message, history)

    # Handle regular message submission
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, chatbot_ui],
        outputs=[chatbot_ui, msg]
    )

    # Handle Enter key press
    msg.submit(
        fn=submit_message,
        inputs=[msg, chatbot_ui],
        outputs=[chatbot_ui, msg]
    )

    # Example buttons: each sends its own fixed question, held in a gr.State.
    for question, btn in zip(example_questions, example_btns):
        btn.click(
            fn=handle_example_question,
            inputs=[gr.State(question), chatbot_ui],
            outputs=[chatbot_ui]
        )

    # Topic buttons prefill the textbox with an editable starter question; the
    # default argument binds each button to its own topic key.
    for (_, topic_key), btn in zip(topics, topic_btns):
        btn.click(
            fn=lambda key=topic_key: f"What should I know about {key}?",
            outputs=[msg]
        )

    # Handle clear button
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot_ui]
    )

    # Footer
    gr.Markdown("""
    <div style="text-align: center; margin-top: 20px; padding: 15px; background: rgba(255,255,255,0.7); border-radius: 15px;">
        <p style="color: #666; font-size: 13px;">
            🏥 <strong>Medical Knowledge Chatbot</strong> | Powered by AI for Educational Purposes
        </p>
    </div>
    """)

# Launch the app. share=True publishes a temporary public *.gradio.live URL;
# debug=True keeps the cell running and surfaces errors in the notebook.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at /content/drive/MyDrive/healthcare-chatbot-model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
  chatbot_ui = gr.Chatbot(
  chatbot_ui = gr.Chatbot(


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://d6480abc73d1e5e288.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://d6480abc73d1e5e288.gradio.live
