# TinyLlama Medical Chatbot - Inference Only
# This notebook loads the TinyLlama base model plus your fine-tuned LoRA adapter (when present) and serves a guarded chatbot through a Flask `/chat` endpoint exposed via ngrok.


In [1]:
# Install required packages
!pip install -q transformers accelerate peft bitsandbytes torch

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!pip install -q Flask pyngrok translate langdetect guardrails-ai alt-profanity-check rstr detoxify

In [4]:
# verify guardrails installed
!python -c "import guardrails; print('guardrails version ->', getattr(guardrails, '__version__', 'unknown'))"


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m758.6/758.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.7/307.7 kB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m122.5 MB/s[0m eta [36m0:00:00[0m
[?25h[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
guardrails version -> unknown


In [6]:
# Configure guardrails CLI (interactive)
# WARNING: the interactive prompt echoes the API key you type into this cell's output.
# Clear the output of this cell before saving, committing or sharing the notebook.
!guardrails configure


Enable anonymous metrics reporting? [Y/n]: Y
Do you wish to use remote inferencing? [Y/n]: Y

[1mEnter API Key below[0m[1m [0m👉 You can find your API Key at [4;94mhttps://hub.guardrailsai.com/keys[0m

API Key: <REDACTED - token removed from saved output; revoke and regenerate it at https://hub.guardrailsai.com/keys>

            Login successful.

            Get started by installing our RegexMatch validator:
            https://hub.guardrailsai.com/validator/guardrails_ai/regex_match

            You can install it by running:
            guardrails hub install hub://guardrails/regex_match

            Find more validators at https://hub.guardrailsai.com
            


In [15]:
# Install specific validators from the Guardrails Hub
!guardrails hub install hub://guardrails/profanity_free
!guardrails hub install hub://guardrails/toxic_language
!guardrails hub install hub://guardrails/regex_match

Installing hub:[35m/[0m[35m/guardrails/[0m[95mprofanity_free...[0m
[2K[32m[=== ][0m Fetching manifest
[2K[32m[=   ][0m Downloading dependencies
[1A[2KTraceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/jsonschema/_format.py", line 304, in <module>
ModuleNotFoundError: No module named 'rfc3987'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/guardrails", line 5, in <module>
    from guardrails.cli import cli
  File "/usr/local/lib/python3.12/dist-packages/guardrails/__init__.py", line 3, in <module>
    from guardrails.guard import Guard
  File "/usr/local/lib/python3.12/dist-packages/guardrails/guard.py", line 54, in <module>
    from guardrails.run import Runner, StreamRunner
  File "/usr/local/lib/python3.12/dist-packages/guardrails/run/__init__.py", line 1, in <module>
    from guardrails.run.async_runner import AsyncRunner
  File "/usr/local/lib/python3.12/di

In [8]:
# Uncomment and run this first if your adapter is zipped
# !unzip /content/med_lora_chat_adapter_zip.zip -d /content/med_lora_chat_adapter/

In [6]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import gradio as gr
from typing import List, Tuple
import os
from guardrails import Guard

In [7]:
# Configuration
# Base model identifier handed to the tokenizer and model `from_pretrained` calls.
MODEL_NAME = "tinyllama/TinyLlama-1.1B-Chat-v1.0"
# Folder holding the extracted adapter; if it does not exist the base model is used as-is.
ADAPTER_PATH = "/content/med_lora_chat_adapter"  # Path to your adapter.zip extracted folder
# Passed as `max_new_tokens` when generating a reply.
MAX_TOKENS = 512
# When True, the base model is loaded with 4-bit (nf4) quantization settings.
USE_4BIT = True

In [8]:
def load_fine_tuned_model():
    """Load the base model and, when available, the fine-tuned adapter.

    Reads the module-level ``MODEL_NAME``, ``ADAPTER_PATH`` and ``USE_4BIT``
    settings.

    Returns:
        tuple: ``(model, tokenizer)``. ``model`` is the adapter-wrapped
        ``PeftModel`` when ``ADAPTER_PATH`` exists on disk; otherwise it is
        the plain base model.
    """
    # Local import so this function does not depend on an edit to the notebook's import cell.
    from transformers import BitsAndBytesConfig

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Add padding token if missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model...")
    model_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        # NOTE(review): this allows model code from the hub repository to run locally;
        # confirm it is actually required for this model.
        "trust_remote_code": True,
    }

    if USE_4BIT:
        # Pass a real BitsAndBytesConfig: that is the documented type for
        # `quantization_config`, whereas a plain dict relies on version-dependent coercion.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        **model_kwargs
    )

    print("Loading your fine-tuned adapter...")
    if os.path.exists(ADAPTER_PATH):
        model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
        print("Adapter loaded successfully!")
    else:
        # Deliberate best-effort: keep serving with the un-tuned base model.
        print("Adapter not found at:", ADAPTER_PATH)
        print("Using base model without fine-tuning...")
        model = base_model

    return model, tokenizer

In [9]:
import torch
from typing import List, Tuple
from flask import Flask
from guardrails import Guard
from guardrails.hub import ProfanityFree, ToxicLanguage, RegexMatch
from langdetect import detect
from translate import Translator

# Upper bound on newly generated tokens per reply (read by MedicalChatbot.generate_response).
MAX_TOKENS = 512

# Create Guard with hub validators
guard = Guard().use_many(
    # on_fail is left at each validator's default.
    # NOTE(review): confirm that default for the installed guardrails version.
    ProfanityFree(),
    ToxicLanguage(),
    # Require the reply to mention a referral to a medical professional.
    # match_type="search" is needed: the validator's default is a full match, and with the
    # previous r".*(...).*" pattern any multi-line reply failed because "." does not match newlines.
    RegexMatch(
        regex=r"(consult|doctor|medical professional|physician)",
        match_type="search",
    ),
)

class MedicalChatbot:
    """Chat wrapper that adds Arabic<->English translation and guard validation around the model.

    Relies on names defined elsewhere in the notebook: ``load_fine_tuned_model``,
    ``guard``, ``MAX_TOKENS``, ``detect`` and ``Translator``.
    """

    def __init__(self):
        """Load the model and tokenizer and create a default translator."""
        # Use the load_fine_tuned_model defined earlier
        self.model, self.tokenizer = load_fine_tuned_model()
        # Not read by any method of this class; callers pass history explicitly.
        self.chat_history = []
        # Default translator; translate_text replaces it with one built for the requested language pair
        self.translator = Translator(to_lang="en")

    def detect_language(self, text: str) -> str:
        """Detect the language of ``text``.

        Returns:
            The language code reported by ``detect``, or ``'en'`` when detection fails.
        """
        try:
            return detect(text)
        except Exception:
            # Deliberate best-effort; `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            return 'en'  # Default to English if detection fails

    def translate_text(self, text: str, dest_lang: str, src_lang: str) -> str:
        """Translate ``text`` from ``src_lang`` to ``dest_lang``.

        Returns:
            The translated text, or the original ``text`` unchanged when translation fails.
        """
        try:
            # Rebuild the translator for the requested language pair
            self.translator = Translator(to_lang=dest_lang, from_lang=src_lang)
            translated_text = self.translator.translate(text)
            return translated_text
        except Exception as e:
            print(f"Translation error: {e}")
            return text  # Return original text on failure

    def format_chat_prompt(self, message: str, history: List[Tuple[str, str]] = None):
        """Build the prompt string for ``message`` preceded by the prior turns.

        Args:
            message: The new user message.
            history: Optional ``(user_msg, assistant_msg)`` pairs; an empty
                ``assistant_msg`` is skipped.

        Returns:
            The tokenizer's chat-template output when the tokenizer provides
            ``apply_chat_template``; otherwise a plain ``User:`` / ``Assistant:`` transcript
            ending with an open ``Assistant: `` turn.
        """
        messages = []
        if history:
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        if hasattr(self.tokenizer, "apply_chat_template"):
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        else:
            prompt = "\n".join([
                f"User: {m['content']}" if m['role'] == 'user' else f"Assistant: {m['content']}"
                for m in messages
            ]) + "\nAssistant: "
        return prompt

    def generate_response(self, message: str, history: List[Tuple[str, str]] = None,
                          temperature: float = 0.7, top_p: float = 0.9):
        """Generate a validated reply to ``message``.

        Arabic input is translated to English before generation, and the final
        reply (or safety / error message) is translated back to Arabic.

        Args:
            message: The user's message.
            history: Optional prior ``(user_msg, assistant_msg)`` turns.
            temperature: Sampling temperature passed to ``generate``.
            top_p: Nucleus-sampling threshold passed to ``generate``.

        Returns:
            The reply text. On validation failure this is a fallback or safety message;
            on any other error it is ``"Error generating response: <error>"``.
        """
        # Bound before the try block so the error handler at the bottom can always read it,
        # even when the failure happens during language detection itself.
        original_lang = 'en'
        try:
            # 1. Detect the language of the incoming message
            original_lang = self.detect_language(message)
            print(f"Detected language: {original_lang}")

            llm_input_message = message
            # 2. If the detected language is 'ar', translate to English
            if original_lang == 'ar':
                llm_input_message = self.translate_text(message, dest_lang='en', src_lang='ar')
                print(f"Translated Arabic to English for LLM: {llm_input_message}")

            # Use the (potentially translated) message for the LLM prompt
            prompt = self.format_chat_prompt(llm_input_message, history)

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=MAX_TOKENS,
                    temperature=temperature,
                    do_sample=True,
                    top_p=top_p,
                    pad_token_id=getattr(self.tokenizer, "pad_token_id", None),
                    eos_token_id=getattr(self.tokenizer, "eos_token_id", None),
                    repetition_penalty=1.1,
                )

            # Drop the prompt tokens so only the newly generated continuation is decoded
            response_ids = outputs[0][inputs['input_ids'].shape[1]:]
            response = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()
            print(f"LLM generated English response: {response}")

            # --- VALIDATION STEP (robust) ---
            try:
                outcome = guard.validate(response)
            except Exception as validation_error:
                # If guard.validate itself raises, block the response instead of returning it unchecked
                print(f"Guard validation raised: {validation_error}")
                final_response = "Response blocked or filtered for safety."
                print(f"Validation failed, returning safety message: {final_response}")
                # Translate safety message back if original was Arabic
                if original_lang == 'ar':
                    final_response = self.translate_text(final_response, dest_lang='ar', src_lang='en')
                    print(f"Translated safety message back to Arabic: {final_response}")
                return final_response

            validation_passed = getattr(outcome, "validation_passed", True)
            validated_output = getattr(outcome, "validated_output", None)

            final_response = response  # Default to the LLM's raw response

            if not validation_passed:
                print("Validation failed.")
                replacement_text_for_medical = "Please consult a qualified doctor for accurate medical advice."
                # Simple heuristic: if the original response mentioned "doctor" or "consult", return the medical fallback
                if any(tok in response.lower() for tok in ("doctor", "consult", "medical professional", "physician")):
                    final_response = replacement_text_for_medical
                    print(f"Applying medical fallback due to validation failure: {final_response}")
                elif validated_output:
                    # some validators may produce a 'fixed' output in validated_output even if validation_passed is False
                    final_response = validated_output
                    print(f"Using validated output after failure: {final_response}")
                else:
                    final_response = "Response blocked or filtered for safety."
                    print(f"Applying generic safety message after validation failure: {final_response}")
            else:
                final_response = validated_output if validated_output is not None else response
                print(f"Validation passed, using final response: {final_response}")

            # 5. If the original language was 'ar', translate the response back to Arabic
            if original_lang == 'ar':
                final_response = self.translate_text(final_response, dest_lang='ar', src_lang='en')
                print(f"Translated English response back to Arabic: {final_response}")

            # 6. Return the final response (translated if necessary)
            return final_response

        except Exception as e:
            print(f"An error occurred during response generation: {e}")
            # Translate error message back if original was Arabic
            error_message = f"Error generating response: {str(e)}"
            if original_lang == 'ar':
                error_message = self.translate_text(error_message, dest_lang='ar', src_lang='en')
            return error_message

# Build the chatbot (MedicalChatbot.__init__ loads the model and tokenizer) and create the
# Flask application object that the /chat route is registered on.
chatbot_instance = MedicalChatbot()
app = Flask(__name__)


Loading tokenizer...
Loading base model...
Loading your fine-tuned adapter...
Adapter not found at: /content/med_lora_chat_adapter
Using base model without fine-tuning...


In [10]:
from flask import request, jsonify

@app.route('/chat', methods=['POST'])
def chat():
    """Handle ``POST /chat``: generate a reply for the submitted message.

    Expects a JSON object body with a non-empty string ``message`` field.

    Returns:
        ``{"response": <reply>}`` on success, or
        ``({"error": "No message provided"}, 400)`` when the body is missing, is not
        valid JSON, is not a JSON object, or carries no usable ``message``.
    """
    # silent=True returns None instead of raising on a missing/invalid JSON body or a
    # wrong content type, so malformed requests get the same JSON 400 as an absent message.
    payload = request.get_json(silent=True)
    # A JSON array or scalar has no .get(); treat it like a missing message.
    user_message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(user_message, str) or not user_message.strip():
        return jsonify({"error": "No message provided"}), 400

    # The generate_response method handles translation internally
    response = chatbot_instance.generate_response(user_message)
    return jsonify({"response": response})

In [11]:
from pyngrok import ngrok
import threading
import time
from google.colab import userdata

# Set the Flask app to run on port 4998
port = 4998

# Get ngrok authtoken from Colab secrets (stored under the secret name 'ngrok')
ngrok_auth_token = userdata.get('ngrok')
if ngrok_auth_token:
    ngrok.set_auth_token(ngrok_auth_token)
    print("ngrok authtoken set.")
else:
    # The message names the secret that is actually looked up above.
    print("Colab secret 'ngrok' not found. Please add your ngrok authtoken to Colab secrets under that name.")


# Start ngrok tunnel
ngrok_tunnel = ngrok.connect(port)
print(f" * ngrok tunnel established at: {ngrok_tunnel.public_url}")


def run_flask_app():
    """Serve the Flask app on ``port``; blocks the calling thread."""
    # use_reloader=False: the server is started from a background thread here.
    app.run(port=port, use_reloader=False)


# Run Flask app in a separate thread. daemon=True so the server thread cannot keep the
# interpreter alive after the main loop below has been interrupted.
thread = threading.Thread(target=run_flask_app, daemon=True)
thread.start()

# Keep the main thread alive to keep the Flask server and ngrok tunnel running
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Shutting down...")
finally:
    # Tear the tunnel down however the loop exits, not only on Ctrl+C.
    ngrok.kill()

ngrok authtoken set.
 * ngrok tunnel established at: https://mellie-transitional-leonardo.ngrok-free.dev
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:4998
INFO:werkzeug:[33mPress CTRL+C to quit[0m


Shutting down...
