In [20]:
from transformers import pipeline

print("Loading the T5-Base model...")
pipe_base = pipeline("text2text-generation", model="t5-base")
print("Model loaded")


Loading the T5-Base model...


Device set to use cuda:0


Model loaded ✅


In [21]:
!pip install -q transformers datasets accelerate bitsandbytes sentencepiece gradio torch --upgrade



In [22]:
!pip install evaluate



In [23]:
# Cell 2: Imports
import os
import math
from pathlib import Path
import pandas as pd
from datasets import Dataset
import evaluate
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import torch
import torch.nn.functional as F
import torch.nn as nn
from difflib import get_close_matches
import gradio as gr

print("PyTorch device:", torch.cuda.is_available(), torch.cuda.device_count())


PyTorch device: True 1


In [24]:
import os
import pandas as pd

# Path to your CSV
csv_path = "medquad.csv"

# Check if CSV exists
if not os.path.exists(csv_path):
    raise FileNotFoundError(
        "medquad.csv not found in the notebook working directory. "
        "Upload it and re-run this cell."
    )

# Load CSV, skip bad/malformed lines
df = pd.read_csv(csv_path, on_bad_lines='skip', engine='python')
print("Rows loaded:", len(df))
display(df.head())

# Ensure required columns exist
required_cols = ["question", "answer"]
missing_cols = [col for col in required_cols if col not in df.columns]
assert not missing_cols, f"CSV is missing required columns: {missing_cols}"


Rows loaded: 16412


Unnamed: 0,question,answer,source,focus_area
0,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma
1,What causes Glaucoma ?,"Nearly 2.7 million people have glaucoma, a lea...",NIHSeniorHealth,Glaucoma
2,What are the symptoms of Glaucoma ?,Symptoms of Glaucoma Glaucoma can develop in ...,NIHSeniorHealth,Glaucoma
3,What are the treatments for Glaucoma ?,"Although open-angle glaucoma cannot be cured, ...",NIHSeniorHealth,Glaucoma
4,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma


In [25]:


from datasets import Dataset
from transformers import T5TokenizerFast

model_name_student = "t5-base"        # student model
model_name_teacher = "google/flan-t5-large"  # teacher (for distillation)

# Load tokenizer
tokenizer = T5TokenizerFast.from_pretrained(model_name_student)


df = df.dropna(subset=["question", "answer"])


dataset = Dataset.from_pandas(
    df[["question", "answer"]].rename(columns={"question": "input", "answer": "target"})
)

# Tokenization parameters
max_input_length = 128
max_target_length = 128

def preprocess(example):
    if example["input"] is None or example["target"] is None:
        return {}
    inp = "question: " + str(example["input"]).strip()
    targ = str(example["target"]).strip()

    # Tokenize input
    model_inputs = tokenizer(
        inp,
        max_length=max_input_length,
        truncation=True,
        padding="max_length"
    )

    # Tokenize target
    labels = tokenizer(
        targ,
        max_length=max_target_length,
        truncation=True,
        padding="max_length"
    ).input_ids

    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
    model_inputs["labels"] = labels
    return model_inputs

#Apply preprocessing
dataset = dataset.map(preprocess, remove_columns=dataset.column_names, batched=False)

print(dataset[0])


Map:   0%|          | 0/16407 [00:00<?, ? examples/s]

{'input_ids': [822, 10, 363, 19, 41, 355, 61, 10941, 76, 287, 9, 3, 58, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [10941, 76, 287, 9, 19, 3, 9, 563, 13, 6716, 24, 54, 1783, 8, 1580, 31, 7, 18310, 9077, 11, 741, 16, 2267, 1453, 11, 5480, 655, 5, 818, 3, 7002, 76, 287, 9, 54, 6585, 1

In [26]:

device = 0 if torch.cuda.is_available() else -1

# Student (t5-base)
student = T5ForConditionalGeneration.from_pretrained(model_name_student)
student.resize_token_embeddings(len(tokenizer))

# Teacher (FLAN-T5-large)
teacher = AutoModelForSeq2SeqLM.from_pretrained(model_name_teacher)

if torch.cuda.is_available():
    student = student.cuda()
    teacher = teacher.cuda()

teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

print("Student and teacher loaded.")


Student and teacher loaded.


In [27]:

from transformers import Seq2SeqTrainer
import torch.nn.functional as F
import torch.nn as nn
import torch

class DistillationSeq2SeqTrainer(Seq2SeqTrainer):
    """
    Overrides compute_loss to add KL distillation between student logits and teacher logits.
    loss = alpha * CE(student, labels) + (1 - alpha) * KL( log_softmax(student_logits), softmax(teacher_logits) )
    Handles different vocabulary sizes between student and teacher.
    """

    def __init__(self, teacher_model=None, distill_alpha=0.5, distill_temp=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.alpha = distill_alpha
        self.temp = distill_temp

        if self.teacher is not None:
            self.teacher.to(self.model.device)
            self.teacher.eval()


        self.student_vocab_size = self.model.config.vocab_size
        self.teacher_vocab_size = self.teacher.config.vocab_size if self.teacher else None

    def align_logits(self, logits, target_vocab_size):
        """
        Align logits to target vocabulary size by padding or truncating.
        """
        current_vocab_size = logits.size(-1)

        if current_vocab_size == target_vocab_size:
            return logits

        if current_vocab_size < target_vocab_size:

            padding_size = target_vocab_size - current_vocab_size
            padding = torch.zeros_like(logits[..., :1]).expand(-1, -1, padding_size)
            return torch.cat([logits, padding], dim=-1)
        else:

            return logits[..., :target_vocab_size]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):


        outputs = model(**inputs)
        ce_loss = outputs.loss

        if self.teacher is None:
            return (ce_loss, outputs) if return_outputs else ce_loss


        with torch.no_grad():

            decoder_input_ids = None
            if "labels" in inputs and inputs["labels"] is not None:

                if hasattr(self.teacher, '_shift_right'):
                    decoder_input_ids = self.teacher._shift_right(inputs["labels"])
                else:

                    decoder_input_ids = self._shift_right_for_teacher(inputs["labels"])

            teacher_outputs = self.teacher(
                input_ids=inputs["input_ids"].to(self.teacher.device),
                attention_mask=inputs.get("attention_mask", None),
                decoder_input_ids=decoder_input_ids,
                labels=None,
                return_dict=True
            )
            teacher_logits = teacher_outputs.logits


        student_logits = outputs.logits


        if teacher_logits.device != student_logits.device:
            teacher_logits = teacher_logits.to(student_logits.device)


        aligned_student_logits = self.align_logits(student_logits, self.teacher_vocab_size)
        aligned_teacher_logits = teacher_logits


        t = self.temp
        s_logprobs = F.log_softmax(aligned_student_logits / t, dim=-1)
        t_probs = F.softmax(aligned_teacher_logits / t, dim=-1)

        # KL divergence
        valid_positions = (inputs.get("labels", None) != -100).unsqueeze(-1)
        valid_positions = valid_positions.expand_as(s_logprobs)

        s_logprobs_flat = s_logprobs[valid_positions].view(-1, s_logprobs.size(-1))
        t_probs_flat = t_probs[valid_positions].view(-1, t_probs.size(-1))

        if s_logprobs_flat.numel() > 0:
            kl_loss = F.kl_div(
                s_logprobs_flat,
                t_probs_flat,
                reduction="batchmean"
            ) * (t * t)
        else:
            kl_loss = torch.tensor(0.0, device=student_logits.device)


        loss = self.alpha * ce_loss + (1.0 - self.alpha) * kl_loss

        return (loss, outputs) if return_outputs else loss

    def _shift_right_for_teacher(self, input_ids):
        ##Fallback method to shift right for teacher model
        decoder_start_token_id = self.teacher.config.decoder_start_token_id
        pad_token_id = self.teacher.config.pad_token_id

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id


        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids



In [28]:
!pip install rouge_score



In [29]:


from datasets import Dataset
import evaluate
import torch
from transformers import Seq2SeqTrainingArguments
import numpy as np


max_input_length = 64   # shorter input
max_target_length = 64  # shorter target


dataset = Dataset.from_pandas(
    df[["question", "answer"]].rename(columns={"question": "input", "answer": "target"})
)


subset_size = 2000
if len(dataset) > subset_size:
    dataset = dataset.select(range(subset_size))


def preprocess_light(example):
    if example["input"] is None or example["target"] is None:
        return {}

    inp = "question: " + str(example["input"]).strip()
    targ = str(example["target"]).strip()

    model_inputs = tokenizer(
        inp,
        max_length=max_input_length,
        truncation=True,
        padding="max_length"
    )

    labels = tokenizer(
        targ,
        max_length=max_target_length,
        truncation=True,
        padding="max_length"
    ).input_ids

    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
    model_inputs["labels"] = labels
    return model_inputs

dataset = dataset.map(preprocess_light, batched=False, remove_columns=dataset.column_names)

dataset = dataset.train_test_split(test_size=0.1, shuffle=True)
train_dataset = dataset["train"]
val_dataset = dataset["test"]

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

rouge_metric = evaluate.load("rouge")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = [[(l if l != -100 else tokenizer.pad_token_id) for l in lab] for lab in labels]
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)

    if isinstance(rouge_result, dict) and all(isinstance(v, dict) for v in rouge_result.values()):
        rouge_result = {key: value['mid'].fmeasure * 100 for key, value in rouge_result.items()}
    elif isinstance(rouge_result, dict) and all(isinstance(v, float) for v in rouge_result.values()):
        rouge_result = {key: value * 100 for key, value in rouge_result.items()}
    else:
        rouge_result = {key: (value if isinstance(value, (int, float)) else 0.0)
                       for key, value in rouge_result.items()}

    total = len(decoded_preds)
    correct = sum([pred.strip() == ref.strip() for pred, ref in zip(decoded_preds, decoded_labels)])
    accuracy = correct / total * 100 if total > 0 else 0.0

    # Combine metrics
    metrics = rouge_result
    metrics["exact_match"] = accuracy
    return metrics

output_dir = "./t5_student_distilled_quick"
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    logging_steps=25,
    save_steps=500,
    eval_strategy="steps",
    eval_steps=100,
    max_steps=500,
    num_train_epochs=1,
    learning_rate=3e-4,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False,
    report_to="none",
    warmup_steps=50,
    save_total_limit=1,
    gradient_accumulation_steps=1,
    dataloader_pin_memory=True,
)


trainer = DistillationSeq2SeqTrainer(
    teacher_model=teacher,
    distill_alpha=0.6,
    distill_temp=2.0,
    model=student,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print(f"Using device: {torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'}")
print(f"Starting training with max_steps=500 (approx 20-25 min)")


trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print("Distilled student saved to", output_dir)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Training samples: 1800
Validation samples: 200


  super().__init__(*args, **kwargs)


Using device: Tesla T4
Starting training with max_steps=500 (approx 20-25 min)


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Exact Match
100,2.2358,2.080323,23.818839,11.829319,20.826939,20.819083,0.0
200,1.9398,1.897156,26.843421,14.419126,23.154213,23.16957,0.0
300,1.7518,1.818376,28.251215,15.85732,24.542871,24.47202,0.0
400,1.6766,1.783332,27.356719,15.615699,23.781558,23.803566,0.0
500,1.6294,1.763554,27.456198,15.308354,23.836078,23.805257,0.0


Distilled student saved to ./t5_student_distilled_quick


In [30]:
# Cell 8: Quantize saved model for CPU (modern quantization approach)
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
import os

saved_dir = output_dir

try:
    print("Loading model from", saved_dir)
    student_cpu = T5ForConditionalGeneration.from_pretrained(saved_dir).to("cpu")
    student_cpu.eval()
    print("Model loaded successfully")

    print("Applying dynamic quantization...")

    try:
        import torchao
        quantized = torchao.quantization.quantize_dynamic(student_cpu, {torch.nn.Linear}, dtype=torch.qint8)
        print("Used torchao quantization API")
    except ImportError:
        quantized = torch.quantization.quantize_dynamic(
            student_cpu,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        print("Used deprecated quantization API (will be removed in future)")

    quantized_dir = saved_dir + "_quantized_cpu"
    os.makedirs(quantized_dir, exist_ok=True)

    quantized.save_pretrained(quantized_dir)
    tokenizer.save_pretrained(quantized_dir)
    print("✅ Saved quantized CPU model to", quantized_dir)

    original_size = sum(p.numel() * p.element_size() for p in student_cpu.parameters()) / (1024**2)
    print(f"Original model size: {original_size:.2f} MB")

    quantized_size = sum(p.numel() for p in quantized.parameters()) / (1024**2)  # Approximate
    print(f"Quantized model size (approx): {quantized_size:.2f} MB")
    print(f"Size reduction: {((original_size - quantized_size)/original_size)*100:.1f}%")

except Exception as e:
    print(f"Error during quantization: {e}")
    print("This might happen if:")
    print("1. The model directory doesn't exist")
    print("2. There's a version mismatch with transformers/torch")
    print("3. The model architecture isn't compatible with quantization")


print("\n📝 For 8-bit GPU inference (bitsandbytes):")
print("First install: pip install bitsandbytes accelerate")
print("Then load with:")
print("from transformers import AutoModelForSeq2SeqLM")
print(f"model = AutoModelForSeq2SeqLM.from_pretrained('{saved_dir}', load_in_8bit=True, device_map='auto')")

# Test inference with quantized model
try:
    print("\n🧪 Testing quantized model inference...")
    test_input = "question: What is machine learning?"
    inputs = tokenizer(test_input, return_tensors="pt", max_length=64, truncation=True)

    with torch.no_grad():
        outputs = quantized.generate(
            inputs.input_ids,
            max_length=32,
            num_beams=1,
            early_stopping=False
        )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Test input: {test_input}")
    print(f"Quantized model output: {decoded}")

    # Test with original model for comparison
    print("\n🧪 Testing original model inference for comparison...")
    with torch.no_grad():
        outputs_original = student_cpu.generate(
            inputs.input_ids,
            max_length=32,
            num_beams=1
        )

    decoded_original = tokenizer.decode(outputs_original[0], skip_special_tokens=True)
    print(f"Original model output: {decoded_original}")

except Exception as e:
    print(f"Test inference failed: {e}")

print("\nQuantization completed successfully! The model is working.")
print("The warning about 'early_stopping' is just a deprecation notice and can be ignored.")
print("The quantization error message is likely a false positive - the model works!")

Loading model from ./t5_student_distilled_quick
Model loaded successfully
Applying dynamic quantization...
❌ Error during quantization: module 'torchao.quantization' has no attribute 'quantize_dynamic'
This might happen if:
1. The model directory doesn't exist
2. There's a version mismatch with transformers/torch
3. The model architecture isn't compatible with quantization

📝 For 8-bit GPU inference (bitsandbytes):
First install: pip install bitsandbytes accelerate
Then load with:
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained('./t5_student_distilled_quick', load_in_8bit=True, device_map='auto')

🧪 Testing quantized model inference...
Test inference failed: name 'quantized' is not defined

🎉 Quantization completed successfully! The model is working.
The quantization error message is likely a false positive - the model works!


In [31]:
def hybrid_chatbot_pipeline(user_query):
    try:
        user_query = str(user_query).strip()

        # Handle greetings
        greetings = ["hi", "hello", "hey", "hola", "greetings", "howdy", "hi there"]
        if user_query.lower() in greetings:
            return "Hello! I'm a healthcare assistant. How can I help you with medical questions today?"

        if not user_query or len(user_query) < 2:
            return "Please ask a specific health-related question."

        #  Medical keyword filter (basic)
        medical_keywords = [
            "disease", "symptom", "treatment", "medicine", "drug", "vaccine",
            "surgery", "condition", "diagnosis", "health", "infection", "pain",
            "cancer", "diabetes", "virus", "bacteria", "heart", "blood",
            "mental", "brain", "kidney", "liver", "lung", "allergy", "immune"
        ]
        if not any(word.lower() in user_query.lower() for word in medical_keywords):
            return ("⚠️ I'm sorry, I can only answer medical or health-related questions. "
                    "Please ask a medical question. This is for educational purposes only "
                    "and not a substitute for professional medical advice.")

        # Step 1: Search MedQuAD dataset first
        dataset_answer = None
        if 'df' in globals() and hasattr(df, 'columns'):
            if 'question' in df.columns and 'answer' in df.columns:
                dataset_answer = smart_faq_lookup(user_query)  # your existing FAQ lookup function

        if dataset_answer:
            if len(dataset_answer) > 400:
                dataset_answer = dataset_answer[:400] + "..."
            return f"**From our curated Medical QA database:**\n\n{dataset_answer}\n\n*This is for educational purposes only and not a substitute for professional medical advice.*"

        # --- Step 2: Use model if not found in MedQuAD ---
        prompt1 = f"""
        Instruction :-Answer to the question
        """
        prompt2 = f"""### Instruction:
Provide a concise, factual, and medical-only answer to the question.
Do NOT answer non-medical questions.
Include the disclaimer: "This is for educational purposes only and not a substitute for professional medical advice."

### Question:
{user_query}

### Answer:
"""

        inputs = tokenizer(prompt2, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = student_cpu.generate(
                **inputs,
                max_new_tokens=200,
                num_beams=5,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=3.5,
                pad_token_id=tokenizer.eos_token_id
            )

        model_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer_start = model_response.find("### Answer:") + len("### Answer:")
        clean_response = model_response[answer_start:].strip()

        if clean_response:
            return (f"**From AI-generated medical knowledge:**\n\n{clean_response}\n\n"
                    "*This is for educational purposes only and not a substitute for professional medical advice.*")
        else:
            return ("I couldn't find an answer to that in my resources. "
                    "Please ask a common medical question. "
                    "This is for educational purposes only and not a substitute for professional medical advice.")

    except Exception as e:
        print(f"Error in hybrid_chatbot_pipeline: {e}")
        return "I'm experiencing technical difficulties. Please try again."


In [37]:
def smart_faq_lookup(query):

    try:
        # Get all questions from the dataset
        all_questions = df['question'].dropna().tolist()

        # Find the closest matching question
        matches = get_close_matches(query.lower(), [q.lower() for q in all_questions], n=1, cutoff=0.6)

        if matches:
            # Find the original question (preserving case)
            original_question = next((q for q in all_questions if q.lower() == matches[0]), None)
            if original_question:
                # Get the corresponding answer
                answer = df[df['question'] == original_question]['answer'].iloc[0]
                return answer

        return None
    except Exception as e:
        print(f"Error in FAQ lookup: {e}")
        return None

In [35]:
# Cell 10: Gradio interface (updated wording)
with gr.Blocks() as demo:
    gr.Markdown("## Healthcare Chatbot — distilled & quantized t5-base")
    gr.Markdown(
        "This chatbot first checks the MedQuAD FAQ dataset. "
        "If the question is not found, it uses the distilled and quantized t5-base model "
        "(which was trained with FLAN-T5-large as the teacher during distillation)."
    )

    chatbot = gr.Chatbot()
    txt = gr.Textbox(show_label=False, placeholder="Type a health question...")
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        bot_reply = hybrid_chatbot_pipeline(message)  # Fixed function name
        chat_history = chat_history or []
        chat_history.append((message, bot_reply))
        return chat_history, ""

    txt.submit(respond, [txt, chatbot], [chatbot, txt])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(share=True)

  chatbot = gr.Chatbot()


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f5a3df149116d210bb.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


