# TinyLlama Medical Chatbot - Inference Only
# This notebook loads your fine-tuned adapter and provides a chatbot interface.


In [None]:
# Install required packages
!pip install -q transformers accelerate peft bitsandbytes torch

In [None]:
!pip install -q Flask pyngrok

In [None]:
!pip install -q guardrails-ai

In [None]:
# Cell 1: Install Guardrails CLI + common validator deps
!pip install -q guardrails-ai alt-profanity-check rstr detoxify

# verify guardrails installed
!python -c "import guardrails; print('guardrails version ->', getattr(guardrails, '__version__', 'unknown'))"


guardrails version -> unknown


In [None]:
# installs guardrails + common validator deps
!pip install -q guardrails-ai alt-profanity-check rstr detoxify
# quick check where guardrails is installed
!python -c "import guardrails, sys, inspect; print('guardrails:', guardrails.__version__, inspect.getfile(guardrails))"


Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'guardrails' has no attribute '__version__'


In [None]:
# Cell 2: Configure guardrails CLI (interactive)
# The command prompts for metrics reporting, remote inferencing and a Guardrails Hub
# API key.
# WARNING: the key typed at the prompt is echoed into this cell's output. Clear the
# output before saving, committing or sharing the notebook.
!guardrails configure


Enable anonymous metrics reporting? [Y/n]: Y
Do you wish to use remote inferencing? [Y/n]: Y

[1mEnter API Key below[0m[1m [0m[2;3mleave empty if you want to keep existing token[0m[3m [0m
👉 You can find your API Key at [4;94mhttps://hub.guardrailsai.com/keys[0m

API Key: [REDACTED]
SUCCESS:guardrails-cli:
            Login successful.

            Get started by installing our RegexMatch validator:
            https://hub.guardrailsai.com/validator/guardrails_ai/regex_match

            You can install it by running:
            guardrails hub install hub://guardrails/regex_match

            Find more validators at https://hub.guardrailsai.com
            


In [None]:
# Cell 3: Install specific validators from the Guardrails Hub
!guardrails hub install hub://guardrails/profanity_free
!guardrails hub install hub://guardrails/toxic_language
!guardrails hub install hub://guardrails/regex_match

Installing hub:[35m/[0m[35m/guardrails/[0m[95mprofanity_free...[0m
[2K[32m[==  ][0m Fetching manifest
[2K[32m[    ][0m Downloading dependencies
[2K[32m[==  ][0m Running post-install setup
[1A[2K✅Successfully installed guardrails/profanity_free version [1;36m0.0[0m.[1;36m0[0m!


[1mImport validator:[0m
from guardrails.hub import ProfanityFree

[1mGet more info:[0m
[4;94mhttps://hub.guardrailsai.com/validator/guardrails/profanity_free[0m

Installing hub:[35m/[0m[35m/guardrails/[0m[95mtoxic_language...[0m
[2K[32m[==  ][0m Fetching manifest
[2K[32m[   =][0m Downloading dependencies
[2K[32m[ ===][0m Running post-install setup2025-10-26 11:55:22.297895: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-26 11:55:22.315683: E extern

In [None]:
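# If your adapter is zipped, uncomment the command below and run this cell first.
# `-o` overwrites files from an earlier extraction instead of stopping at a prompt.
# !unzip -o /content/med_lora_chat_adapter_zip.zip -d /content/med_lora_chat_adapter/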

Archive:  /content/med_lora_chat_adapter_zip.zip
replace /content/med_lora_chat_adapter/adapter_config.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import gradio as gr
from typing import List, Tuple
import os
from guardrails import Guard

In [None]:
# Configuration — module-level settings read by load_fine_tuned_model() and the chatbot
MODEL_NAME = "tinyllama/TinyLlama-1.1B-Chat-v1.0"  # base model/tokenizer id passed to from_pretrained
ADAPTER_PATH = "/content/med_lora_chat_adapter"  # Folder the adapter zip was extracted into
MAX_TOKENS = 512  # passed to generate() as max_new_tokens
USE_4BIT = True  # when True, the base model is loaded with a 4-bit (nf4) quantization config

In [None]:
def load_fine_tuned_model():
    """Build the inference model: the base checkpoint plus the LoRA adapter, if present.

    Reads the module-level MODEL_NAME, ADAPTER_PATH and USE_4BIT settings.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the adapter-wrapped model when
        ADAPTER_PATH exists on disk; otherwise it is the plain base model.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Reuse EOS as the padding token when the tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model...")
    load_options = dict(
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    if USE_4BIT:
        # Ask for 4-bit weights with fp16 compute and the nf4 quantization type.
        load_options["quantization_config"] = {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": torch.float16,
            "bnb_4bit_quant_type": "nf4",
        }
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_options)

    print("Loading your fine-tuned adapter...")
    if not os.path.exists(ADAPTER_PATH):
        # No adapter on disk: report it and serve the un-tuned base model.
        print("Adapter not found at:", ADAPTER_PATH)
        print("Using base model without fine-tuning...")
        return base_model, tokenizer

    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
    print("Adapter loaded successfully!")
    return model, tokenizer

In [None]:
# --- Robust fix: don't use on_fail="replace" (invalid) ---
# Use on_fail defaults (noop) and apply replacement logic after validation.

import torch
from typing import List, Tuple
from flask import Flask
from guardrails import Guard
from guardrails.hub import ProfanityFree, ToxicLanguage, RegexMatch

MAX_TOKENS = 512

# Output guard built from hub validators. No on_fail is passed ("replace" is not a
# supported value), so MedicalChatbot.generate_response decides what to send back when
# a reply fails validation.
#
# The referral pattern carries two inline flags:
#   (?i) - match "Doctor" / "Consult" regardless of case, consistent with the
#          lower-cased keyword check in MedicalChatbot.generate_response;
#   (?s) - let "." span newlines, so a multi-line reply that mentions a doctor still
#          matches when the validator applies the pattern to the whole text.
guard = Guard().use_many(
    ProfanityFree(),
    ToxicLanguage(),
    RegexMatch(r"(?is).*(consult|doctor|medical professional|physician).*"),
)

# Removed the unimplemented function, will use the one defined earlier.

class MedicalChatbot:
    """Medical Q&A chatbot: fine-tuned causal LM plus Guardrails output validation.

    Attributes:
        model: Model returned by the module-level ``load_fine_tuned_model``.
        tokenizer: Tokenizer returned by ``load_fine_tuned_model``.
        chat_history: Empty list created at construction; no method of this class
            reads or updates it.
    """

    # Replies substituted for model output that fails validation.
    MEDICAL_FALLBACK = "Please consult a qualified doctor for accurate medical advice."
    BLOCKED_MESSAGE = "Response blocked or filtered for safety."
    # Substrings (looked up in the lower-cased reply) that count as a medical referral.
    REFERRAL_TERMS = ("doctor", "consult", "medical professional", "physician")

    def __init__(self):
        """Load the model and tokenizer via ``load_fine_tuned_model``."""
        self.model, self.tokenizer = load_fine_tuned_model()
        self.chat_history = []

    def format_chat_prompt(self, message: str, history: List[Tuple[str, str]] = None):
        """Turn the conversation so far plus the new message into one prompt string.

        Args:
            message: The new user message; always appended as the last turn.
            history: Optional ``(user_msg, assistant_msg)`` pairs, oldest first. A
                falsy ``assistant_msg`` contributes no assistant turn.

        Returns:
            The tokenizer's chat-template rendering (with a generation prompt) when
            the tokenizer has ``apply_chat_template``; otherwise a plain
            ``User: ...`` / ``Assistant: ...`` transcript ending in ``Assistant: ``.
        """
        messages = []
        for user_msg, assistant_msg in history or []:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        if hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        transcript = "\n".join(
            f"User: {m['content']}" if m['role'] == 'user' else f"Assistant: {m['content']}"
            for m in messages
        )
        return transcript + "\nAssistant: "

    def generate_response(self, message: str, history: List[Tuple[str, str]] = None,
                          temperature: float = 0.7, top_p: float = 0.9):
        """Generate a reply to ``message`` and pass it through the module-level ``guard``.

        Args:
            message: The user's message.
            history: Optional prior ``(user_msg, assistant_msg)`` turns.
            temperature: Sampling temperature forwarded to ``generate``.
            top_p: Nucleus-sampling threshold forwarded to ``generate``.

        Returns:
            Always a string: the validated reply, a safe replacement when validation
            fails or the guard itself raises, or ``"Error generating response: ..."``
            when prompt building, tokenization or generation raises.
        """
        try:
            prompt = self.format_chat_prompt(message, history)

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=MAX_TOKENS,
                    temperature=temperature,
                    do_sample=True,
                    top_p=top_p,
                    pad_token_id=getattr(self.tokenizer, "pad_token_id", None),
                    eos_token_id=getattr(self.tokenizer, "eos_token_id", None),
                    repetition_penalty=1.1,
                )

            # generate() returns prompt + continuation; keep only the new tokens.
            response_ids = outputs[0][inputs['input_ids'].shape[1]:]
            response = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()

            try:
                outcome = guard.validate(response)
            except Exception:
                # A crashing validator must not let unchecked text through.
                return self.BLOCKED_MESSAGE

            return self._select_reply(response, outcome)

        except Exception as e:
            return f"Error generating response: {str(e)}"

    @classmethod
    def _select_reply(cls, response: str, outcome) -> str:
        """Choose the text to send back for ``response`` given the guard's ``outcome``.

        ``outcome`` is read defensively via ``getattr`` for ``validation_passed``
        (treated as True when absent) and ``validated_output``.

        On a validation failure the raw ``response`` is never returned:
          * no referral term in the reply -> the medical-referral fallback (this is
            the case in which the referral validator cannot have passed);
          * otherwise, a ``validated_output`` that differs from the raw reply is
            treated as a validator's repaired text and returned;
          * otherwise the generic blocked message.
        """
        validated_output = getattr(outcome, "validated_output", None)
        if getattr(outcome, "validation_passed", True):
            return validated_output if validated_output is not None else response

        lowered = response.lower()
        if not any(term in lowered for term in cls.REFERRAL_TERMS):
            return cls.MEDICAL_FALLBACK

        # NOTE(review): with validators that take no corrective action,
        # validated_output is presumably the very text that failed validation -
        # confirm against the Guardrails version in use. Only a value that differs
        # from the raw reply is accepted as a repair.
        if validated_output and validated_output != response:
            return validated_output
        return cls.BLOCKED_MESSAGE

# Flask app (do not run a persistent server inside Colab unless you want to)
app = Flask(__name__)
# Initialize chatbot_instance after the class definition.
# Constructing MedicalChatbot calls load_fine_tuned_model(), so this line loads the
# tokenizer, the base model and (when present) the adapter.
chatbot_instance = MedicalChatbot()

Loading tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading base model...
Loading your fine-tuned adapter...
Adapter loaded successfully!


In [None]:
from flask import request, jsonify

@app.route('/chat', methods=['POST'])
def chat():
    """Answer one chat message posted as JSON.

    Expects a request body of the form ``{"message": "<text>"}``.

    Returns:
        ``{"response": <reply>}`` on success, or ``{"error": "No message provided"}``
        with HTTP 400 when the body is not a JSON object or does not carry a
        non-empty string under ``message``.
    """
    # silent=True yields None for a missing or malformed JSON body instead of raising,
    # so every bad request receives the same JSON error below.
    payload = request.get_json(silent=True)
    # The body may be valid JSON without being an object (e.g. a list or a string).
    user_message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(user_message, str) or not user_message.strip():
        return jsonify({"error": "No message provided"}), 400

    response = chatbot_instance.generate_response(user_message)
    return jsonify({"response": response})

In [None]:
from pyngrok import ngrok
import threading
import time
from google.colab import userdata

# Port shared by the Flask server and the ngrok tunnel
port = 4998

# Get ngrok authtoken from Colab secrets (stored under the secret name 'ngrok')
ngrok_auth_token = userdata.get('ngrok')
if ngrok_auth_token:
    ngrok.set_auth_token(ngrok_auth_token)
    print("ngrok authtoken set.")
else:
    print("Colab secret 'ngrok' not found or empty. Please add your ngrok authtoken under that name.")


# Function to run the Flask app
def run_flask_app():
    """Serve the Flask app on `port`; blocks the calling thread while serving."""
    app.run(port=port, use_reloader=False)

# Start the server in a background thread, unless an earlier run of this cell left one
# alive in this kernel: starting a second server on the same port fails with
# "Address already in use". A reused thread keeps serving the `app` object it was
# started with.
previous_thread = globals().get("thread")
if isinstance(previous_thread, threading.Thread) and previous_thread.is_alive():
    print(f"Flask server already running on port {port}; reusing it.")
else:
    # daemon=True: the serving thread must not keep the process alive after shutdown.
    thread = threading.Thread(target=run_flask_app, daemon=True)
    thread.start()

# Start ngrok tunnel
ngrok_tunnel = ngrok.connect(port)
print(f" * ngrok tunnel established at: {ngrok_tunnel.public_url}")

# Keep the main thread alive to keep the Flask server and ngrok tunnel running
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Shutting down...")
finally:
    # Tear the tunnel down however the wait loop ends, not only on Ctrl-C / interrupt.
    ngrok.kill()

ngrok authtoken set.
 * ngrok tunnel established at: https://mellie-transitional-leonardo.ngrok-free.dev
 * Serving Flask app '__main__'
 * Debug mode: off


Address already in use
Port 4998 is in use by another program. Either identify and stop that program, or start the server with a different port.
INFO:werkzeug:127.0.0.1 - - [26/Oct/2025 12:06:22] "POST /chat HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [26/Oct/2025 12:06:50] "POST /chat HTTP/1.1" 200 -


Shutting down...
