In [None]:
!pip install -q datasets transformers accelerate peft bitsandbytes

In [None]:
import os
import textwrap

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

In [None]:
# -------------------------
# Config
# -------------------------
# --- Model ---
MODEL_NAME = "tinyllama/TinyLlama-1.1B-Chat-v1.0"
USE_4BIT = True      # load the base model 4-bit quantized
MAX_TOKENS = 512     # max_length used when tokenizing each chat example

# --- Data subset ---
NUM_TRAIN_ROWS = 5000  # rows taken from the head of the train split
VAL_SIZE = 500         # trailing rows of that subset held out for validation

# --- Optimisation ---
EPOCHS = 3
LEARNING_RATE = 2e-4
BATCH_SIZE = 1    # per-device train batch size
GRAD_ACCUM = 16   # gradient accumulation steps

# --- Output locations ---
OUTPUT_DIR = "/content/tinyllama_med_lora_run"    # Trainer output directory
ADAPTER_OUT = "/content/med_lora_chat_adapter"    # saved adapter + tokenizer

In [None]:
# -------------------------
# 1) Load dataset & show head
# -------------------------
print("Loading dataset from Hugging Face...")
ds = load_dataset("ruslanmv/ai-medical-chatbot")

print("\nAvailable splits:", ds.keys())
print("\nTrain split size (raw):", len(ds["train"]))

# Select only the first NUM_TRAIN_ROWS rows from train (capped at the split size
# so select() never receives out-of-range indices).
n_selected = min(NUM_TRAIN_ROWS, len(ds["train"]))
train_raw = ds["train"].select(range(n_selected))

# Report the counts from what was actually selected: when the dataset is smaller
# than NUM_TRAIN_ROWS, "NUM_TRAIN_ROWS - VAL_SIZE" would overstate the train size.
n_val_planned = min(VAL_SIZE, n_selected)
n_train_planned = n_selected - n_val_planned
print(f"\nSelected first {n_selected} rows for experiments (will use {n_train_planned} train + {n_val_planned} val).")

# Print column names and first 5 rows (head)
print("\nColumn names:", train_raw.column_names)
print("\nFirst 5 rows (head):")
for i, ex in enumerate(train_raw.select(range(min(5, n_selected)))):
    print(f"\n--- Row {i} ---")
    for k, v in ex.items():
        print(f"{k}: {str(v)[:400]}")  # truncate long fields

In [None]:
# -------------------------
# 2) Convert to chat-style messages
# -------------------------
def to_chat(example):
    """
    Convert one dataset row into a two-turn chat record.

    The user (patient) text and the assistant (doctor) text are looked up under
    several common column names. If one or both are missing and the row has a
    non-empty "text" column, that column is used as a fallback:

    * both missing, "text" has >= 2 non-empty lines: first line is the user turn,
      second line is the assistant turn;
    * only the user turn missing (or "text" cannot be split): the whole "text"
      value is the user turn;
    * only the assistant turn missing: the second non-empty line of "text" is the
      assistant turn.

    Values are converted with str() and stripped; an empty user or assistant turn
    is replaced with a fixed placeholder sentence.

    Args:
        example: mapping of column name -> value for a single row.

    Returns:
        {"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}
    """
    # candidate keys (common variants). "text" is deliberately NOT listed here:
    # it is handled below so that a combined dialog is split instead of being
    # copied wholesale (answer included) into the user turn.
    user_keys = ["Patient", "patient", "question", "Question", "prompt", "user", "User", "q", "Q"]
    doc_keys  = ["Doctor", "doctor", "answer", "Answer", "response", "Response", "reply", "Reply", "a"]

    def first_present(keys):
        # First truthy value found under any of the candidate keys, else None.
        for key in keys:
            if key in example and example[key]:
                return example[key]
        return None

    user_text = first_present(user_keys)
    doc_text = first_present(doc_keys)

    # Fallback to a single "text" column when a dedicated column is missing.
    raw_text = first_present(["text"])
    if raw_text is not None and (user_text is None or doc_text is None):
        lines = []
        if isinstance(raw_text, str):
            lines = [ln.strip() for ln in raw_text.split("\n") if ln.strip()]

        if user_text is None and doc_text is None and len(lines) >= 2:
            # Combined dialog: question on the first line, answer on the second.
            user_text, doc_text = lines[0], lines[1]
        elif user_text is None:
            # The answer has its own column, or the text cannot be split.
            user_text = raw_text
        elif len(lines) >= 2:
            doc_text = lines[1]

    # str() guards against non-string column values (numbers, etc.).
    user_text = str(user_text or "").strip()
    doc_text = str(doc_text or "").strip()

    # Ensure non-empty user (if empty, use a placeholder)
    if user_text == "":
        user_text = "Hello, I have a medical question."

    # If assistant reply missing, keep it short
    if doc_text == "":
        doc_text = "I need more information to answer that."

    return {"messages": [{"role": "user", "content": user_text}, {"role": "assistant", "content": doc_text}]}

print("\nConverting selected rows to chat 'messages' pairs (this will be fast for 5k rows)...")
# Row-wise map: every row gains a "messages" column produced by to_chat.
chat_small = train_raw.map(to_chat)

# Sanity check: show the converted pair for the first row.
print("\nSample converted message (index 0):")
sample_messages = chat_small[0]["messages"]
print(sample_messages)

In [None]:
# -------------------------
# 3) Create train / val splits from the 5000 rows
# -------------------------
total_small = len(chat_small)
# The subset can never be larger than the requested cap.
assert total_small <= NUM_TRAIN_ROWS

# The last VAL_SIZE rows become validation; everything before them is training.
train_count = total_small - VAL_SIZE if total_small > VAL_SIZE else 0
train_indices = range(train_count)
val_indices = range(train_count, total_small)
train_ds = chat_small.select(train_indices)
val_ds = chat_small.select(val_indices)
print(f"\nTrain rows: {len(train_ds)}, Val rows: {len(val_ds)}")

# -------------------------
# 4) Tokenize using TinyLlama chat template
# -------------------------
print("\nLoading tokenizer and preparing tokenization...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# Padding to a fixed length needs a pad token; reuse EOS when none is defined.
pad_token_missing = tokenizer.pad_token is None
if pad_token_missing:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_chat(example):
    """
    Render one chat record with the tokenizer's chat template and tokenize it.

    Args:
        example: row with a "messages" list of {"role", "content"} dicts.

    Returns:
        The tokenizer output (padded/truncated to MAX_TOKENS) with an added
        "labels" entry: a copy of input_ids where every padding position
        (attention_mask == 0) is replaced by -100.
    """
    # build the chat template string using the tokenizer's chat template
    prompt = tokenizer.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False)
    tokens = tokenizer(
        prompt,
        truncation=True,
        max_length=MAX_TOKENS,
        padding="max_length"
    )
    # Labels follow input_ids (causal LM), but padding must not contribute to the
    # loss: -100 is the ignore index of the cross-entropy loss. Masking by
    # attention_mask (not by pad id) keeps a real EOS token trainable even when
    # the pad token is the EOS token.
    tokens["labels"] = [
        token_id if is_real else -100
        for token_id, is_real in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

print("Tokenizing train set (this may take a while)...")
train_tokenized = train_ds.map(tokenize_chat, remove_columns=train_ds.column_names)
print("Tokenizing validation set...")
val_tokenized = val_ds.map(tokenize_chat, remove_columns=val_ds.column_names)

# Expose only the model inputs, returned as torch tensors.
cols = ["input_ids", "attention_mask", "labels"]
for tokenized_split in (train_tokenized, val_tokenized):
    tokenized_split.set_format(type="torch", columns=cols)

print("\nTokenization complete. Example shapes (train index 0):")
first_train_row = train_tokenized[0]
print({name: first_train_row[name].shape for name in cols})

In [None]:
# -------------------------
# 5) Load base model and attach LoRA adapter
# -------------------------
print("\nLoading base model and preparing LoRA...")

# Options shared by the quantized and non-quantized paths; the 4-bit path only
# adds a quantization config on top of them.
model_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "trust_remote_code": True,
}
if USE_4BIT:
    # Passing load_in_4bit directly to from_pretrained is deprecated; the
    # supported way is a BitsAndBytesConfig. Only load_in_4bit is set, so the
    # quantization settings stay at the library defaults, as before.
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)

if USE_4BIT:
    # PEFT's preparation step for training on top of a k-bit quantized base
    # model. Gradient checkpointing is left off to keep the training cost
    # unchanged; set it to True if GPU memory is tight.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

# LoRA config: adapters on the attention query/key/value projections only
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, peft_config)

print("\nTrainable parameters (should be LoRA params only):")
# get_peft_model returns a PeftModel, which provides this helper; no broad
# try/except here so that a genuine failure is not silently hidden.
model.print_trainable_parameters()

In [None]:
# -------------------------
# 6) TrainingArguments + Trainer
# -------------------------
print("\nPreparing Trainer...")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=50,
    save_total_limit=3,
    # Evaluate and checkpoint on the same step schedule.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    # The datasets already expose exactly the model inputs (set_format), so the
    # Trainer must not try to drop columns itself.
    remove_unused_columns=False,
    # Keep logging local: avoids an interactive experiment-tracker login prompt
    # when such an integration happens to be installed in the runtime.
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
)

In [None]:
# -------------------------
# 7) Launch training
# -------------------------

train_result = trainer.train()
print("\nTraining finished. Metrics:")
print(train_result)

# -------------------------
# 8) Save adapter & tokenizer
# -------------------------
print(f"\nSaving adapter + tokenizer to {ADAPTER_OUT} ...")
os.makedirs(ADAPTER_OUT, exist_ok=True)
# Both objects are written into the same directory, model first.
for artifact in (model, tokenizer):
    artifact.save_pretrained(ADAPTER_OUT)
print("Saved.")

In [None]:
# -------------------------
# 9) Interactive chat loop
# -------------------------
from peft import PeftModel
import torch
import textwrap

print("\nStarting interactive chat with the fine-tuned adapter. Type 'exit' to quit.\n")

# Reload a fresh base model and attach the saved adapter to it.
chat_model_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "trust_remote_code": True,
}
if USE_4BIT:
    # Passing load_in_4bit directly to from_pretrained is deprecated; use a
    # BitsAndBytesConfig (library-default quantization settings, as before).
    chat_model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **chat_model_kwargs)
model = PeftModel.from_pretrained(model, ADAPTER_OUT)
model.eval()  # inference only: make sure dropout is disabled

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# helper to keep a bounded conversation history (prevent runaway token length)
def recent_messages(conv, max_user_assistant_pairs=6):
    """
    Return the first message of `conv` plus at most the last
    2 * max_user_assistant_pairs of the remaining messages.

    A conversation with zero or one message is returned as-is (same object);
    otherwise a new list is built and `conv` is left untouched.
    """
    if len(conv) <= 1:
        return conv

    limit = 2 * max_user_assistant_pairs
    history = conv[1:]
    # Trim only when the history is strictly longer than the limit.
    if len(history) - limit > 0:
        history = history[-limit:]
    return [conv[0]] + history

# Initialize the conversation memory with a single system message.
# The chat loop appends user and assistant turns to this list, and
# recent_messages() always keeps element 0 (this system prompt) when trimming.
conversation = [
    {"role": "system", "content": """You are Dr. AI, a professional, empathetic, and friendly medical assistant. Follow these rules:

1. Provide clear, concise, and easy-to-understand answers in a conversational style.
2. Offer practical and safe guidance whenever possible.
3. If unsure about a condition or if the situation could be serious, advise the user to consult a qualified healthcare professional.
4. Ask clarifying questions if needed and remember recent conversation context.
5. Be polite, empathetic, and avoid unnecessary medical jargon (or explain it simply).
6. always answer in simplest terms.

Always respond responsibly, keeping the user's safety and understanding in mind.
"""}
]

device = next(model.parameters()).device  # model device (works with device_map="auto")

# Interactive loop: read a user message, generate a reply from the bounded recent
# history, print it, and store both turns. Ends on exit/quit/stop or Ctrl-C.
try:
    while True:
        user_input = input("You: ").strip()
        if user_input.lower() in ["exit", "quit", "stop"]:
            print("Session ended.")
            break
        if not user_input:
            # Nothing was typed: do not send an empty user turn to the model.
            continue

        conversation.append({"role": "user", "content": user_input})

        # build a truncated recent conversation to avoid exceeding tokenizer max length
        conv_for_prompt = recent_messages(conversation, max_user_assistant_pairs=6)
        prompt = tokenizer.apply_chat_template(conv_for_prompt, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # The output sequence starts with the prompt tokens. Cut the prompt off at
        # token level: slicing the decoded text by len(prompt) is wrong because
        # decoding with skip_special_tokens=True drops special tokens that the
        # prompt string contains, so character offsets no longer line up.
        reply_token_ids = generated[0][prompt_token_count:]
        reply = tokenizer.decode(reply_token_ids, skip_special_tokens=True).strip()
        if reply == "":
            reply = "I'm sorry — I couldn't generate a response. Could you rephrase?"

        print("\nDoctor:\n")
        print(textwrap.fill(reply, width=100))
        print()

        # append assistant response to conversation memory
        conversation.append({"role": "assistant", "content": reply})

except KeyboardInterrupt:
    print("\n\nInterrupted by user — session ended.")



In [None]:
import shutil
from google.colab import files

# Compress the saved adapter folder into a zip file and download it.
# Paths are derived from ADAPTER_OUT (instead of repeating the literal path) so
# this cell keeps working when the output location is changed in the config.
adapter_dir = ADAPTER_OUT.rstrip("/")
# make_archive appends ".zip" to the base name and returns the archive's path.
archive_path = shutil.make_archive(f"{adapter_dir}_zip", "zip", adapter_dir)
files.download(archive_path)