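# ChatBud - Child-Friendly AI Assistant

Runs ChatBud, a chat assistant for children aged 9–11. It combines:

- a LoRA fine-tune of Gemma 3 4B-IT, loaded in 4-bit;
- a hardcoded child-safety system prompt;
- a small Flask API supporting text, images and per-conversation memory;
- a public address through a Cloudflare quick tunnel.

Run the cells in order on a Colab GPU runtime. You need a Hugging Face token (Colab secret `HF_TOKEN`) and the LoRA adapter saved on Google Drive.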



## Cell 1: Install Dependencies

In [None]:
!pip install -q transformers>=4.50.0 accelerate>=1.0.0 peft>=0.12.0 bitsandbytes>=0.43.3
!pip install -q huggingface_hub sentencepiece
!pip install -q flask flask-cors Pillow

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    raise RuntimeError("‚ùå No GPU! Go to Runtime ‚Üí Change runtime type ‚Üí T4 GPU")

## Cell 2: Login to Hugging Face

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Resolve the Hugging Face token WITHOUT hardcoding it. Order: Colab secret
# "HF_TOKEN", then the HF_TOKEN environment variable, then an interactive prompt.
# SECURITY: a token used to be hardcoded in this cell; treat it as leaked and
# revoke it at https://huggingface.co/settings/tokens.
hf_token = None
try:
    from google.colab import userdata
    hf_token = userdata.get("HF_TOKEN")
except Exception as exc:  # e.g. not on Colab, or secret missing / not shared with this notebook
    print(f"Colab secret HF_TOKEN unavailable ({type(exc).__name__}); trying other sources.")

if not hf_token:
    hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    hf_token = getpass("Hugging Face token: ")

login(token=hf_token)
# Exported so later library calls in this kernel can pick the token up too.
os.environ["HF_TOKEN"] = hf_token
print("✅ Logged in to Hugging Face")

## Cell 3: Mount Google Drive

In [None]:
import os

from google.colab import drive

drive.mount("/content/drive")

# LoRA adapter directory on Google Drive, loaded by PeftModel in Cell 5.
# NOTE(review): the repeated folder name looks like a nested save directory —
# confirm it matches the layout on Drive.
ADAPTER_DIR = "/content/drive/MyDrive/gemma3_child_friendly_lora/gemma3_child_friendly_lora"

# The adapter must be a directory, so a stray file at this path is also an error.
if not os.path.isdir(ADAPTER_DIR):
    raise FileNotFoundError(f"❌ Adapter not found: {ADAPTER_DIR}")
print(f"✅ Adapter found: {ADAPTER_DIR}")

## Cell 4: Load Model

In [None]:
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig

MODEL_ID = "google/gemma-3-4b-it"

# 4-bit NF4 weights with double (nested) quantization; matmuls run in bfloat16.
# Presumably chosen so the 4B model fits the T4 GPU requested in Cell 1.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# The processor handles both the chat template/tokenizer and image preprocessing.
print("Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID)
print("✅ Processor loaded")

print("\nLoading model (2-3 min)...")
base_model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
base_model.eval()  # inference only: disables dropout
print("✅ Base model loaded")

## Cell 5: Load LoRA + Test

In [None]:
from peft import PeftModel

# Wrap the quantized base model with the fine-tuned LoRA adapter from Drive.
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model.eval()
print("✅ Fine-tuned model ready!")

# =============================================================================
# SYSTEM PROMPT (hardcoded for child safety)
# =============================================================================
SYSTEM_PROMPT = """You are ChatBud, a friendly and safe helper for children aged 9–11.
Speak with simple words (use the least number of words as possible) and short sentences (concise), like you're talking to a smart kid, and keep answers brief (about 1–4 short sentences as a maximum).
Never swear, use rude or sexual language, or describe violence, self-harm, or sex in graphic detail.
Do not give risky instructions, dares, or tips that could hurt someone in real life or online.
If a problem sounds serious or scary, tell the child to stop, stay safe, and talk to a trusted adult such as a parent, caregiver, teacher, or counselor."""

print(f"\n📋 System prompt loaded ({len(SYSTEM_PROMPT)} chars)")

# =============================================================================
# Smoke test: one system + user exchange, so problems surface here rather
# than in the server. Any failure is printed and re-raised to stop the run.
# =============================================================================
print("\n🧪 Testing model with system prompt...")
try:
    test_messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": "What are elephants?"}]},
    ]

    test_inputs = processor.apply_chat_template(
        test_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)  # follow wherever device_map="auto" placed the model

    prompt_length = test_inputs["input_ids"].shape[-1]
    print(f"   Input tokens: {prompt_length}")

    with torch.inference_mode():
        test_output = model.generate(
            **test_inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    response = processor.decode(test_output[0][prompt_length:], skip_special_tokens=True)
    print(f"   ✅ Response: {response[:100]}...")
except Exception as e:
    print(f"   ❌ Test failed: {e}")
    raise

print("\n" + "=" * 50)
print("✅ Ready! Run Cell 6 to start server.")
print("=" * 50)

## Cell 6: Start Server (with Memory + Images)

**Copy the public `trycloudflare.com` URL this cell prints into the ChatBud UI settings (⚙️), and keep the cell running.**

In [None]:
# =============================================================================
# CHATBUD SERVER
# - System prompt (child-safe)
# - Conversation history (8K context)
# - Image support
# =============================================================================

from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from io import BytesIO
import base64
import threading
import subprocess
import time
import re

!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared 2>/dev/null
!chmod +x cloudflared
!fuser -k 5001/tcp 2>/dev/null || true

app = Flask(__name__)
CORS(app)
PORT = 5001

# =============================================================================
# CONFIGURATION
# =============================================================================
MAX_CONTEXT_TOKENS = 8192  # Gemma 3 4B max context
MAX_NEW_TOKENS = 256       # Max response length
MAX_HISTORY_TURNS = 20     # Keep last N conversation turns

# Child-safe system prompt (HARDCODED - cannot be changed by users)
SYSTEM_PROMPT = """You are ChatBud, a friendly and safe helper for children aged 9‚Äì11.
Speak with simple words (use the least number of words as possible) and short sentences (concise), like you're talking to a smart kid, and keep answers brief (about 1‚Äì4 short sentences as a maximum).
Never swear, use rude or sexual language, or describe violence, self-harm, or sex in graphic detail.
Do not give risky instructions, dares, or tips that could hurt someone in real life or online.
If a problem sounds serious or scary, tell the child to stop, stay safe, and talk to a trusted adult such as a parent, caregiver, teacher, or counselor."""

# Store conversation history per session
conversations = {}


def decode_base64_image(base64_string, max_size=512):
    """Decode a base64 image (raw or data URL) into a downscaled RGB PIL image.

    Args:
        base64_string: Base64 text, optionally prefixed with a data-URL header
            such as "data:image/png;base64," (everything up to the first comma
            is dropped).
        max_size: Longest allowed side in pixels. Larger images are shrunk with
            their aspect ratio preserved; smaller ones are left untouched.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` if the input is not a string or
        cannot be decoded/opened as an image (the error is printed).
    """
    try:
        # Keep only the payload after a data-URL header, if one is present.
        payload = base64_string.split(",", 1)[-1]
        image_bytes = base64.b64decode(payload)

        # Close the source image once we have an independent RGB copy.
        with Image.open(BytesIO(image_bytes)) as source:
            image = source.convert("RGB")

        # thumbnail() only ever shrinks, preserves the aspect ratio, and never
        # yields a zero-pixel side (int() truncation could, for a very thin image).
        image.thumbnail((max_size, max_size), Image.LANCZOS)
        return image
    except Exception as e:
        print(f"Image decode error: {e}")
        return None


def build_messages(conversation_id, user_message, image=None):
    """Assemble the chat-template message list for one generation request.

    The list is:
    1. the hardcoded SYSTEM_PROMPT;
    2. up to MAX_HISTORY_TURNS stored user/assistant exchanges for this
       conversation;
    3. the new user message (image part first, if any, then text).

    Side effect: an empty history list is created for unseen conversation ids.
    The new message is NOT appended to history here; the caller does that.

    Args:
        conversation_id: Key into ``conversations``.
        user_message: The child's text. If empty, "What do you see in this
            picture?" is used instead.
        image: Optional PIL image attached to this message.

    Returns:
        ``(messages, current_user_msg)``: the full message list and the new
        user message dict (which is also its last element).
    """
    history = conversations.setdefault(conversation_id, [])

    # History is stored as user/assistant message pairs, so N turns means 2*N
    # messages. The even count also keeps the window starting on a user message.
    # Guard 0 explicitly: history[-0:] would return the whole list.
    window = 2 * MAX_HISTORY_TURNS
    recent_history = history[-window:] if window > 0 else []

    current_content = []
    if image is not None:
        current_content.append({"type": "image", "image": image})
    text = user_message if user_message else "What do you see in this picture?"
    current_content.append({"type": "text", "text": text})
    current_user_msg = {"role": "user", "content": current_content}

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        *recent_history,
        current_user_msg,
    ]
    return messages, current_user_msg


# Flask runs with threaded=True, but there is one GPU model and one shared
# ``conversations`` dict, so generation requests are handled one at a time.
_generation_lock = threading.Lock()


def _encode_messages(messages):
    """Apply the chat template and return model-ready tensors on the model's device.

    Floating-point tensors (image pixel values) are cast to bfloat16, matching
    the compute dtype; integer tensors (token ids, masks) are only moved.
    """
    return processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device, dtype=torch.bfloat16)


def generate_response(conversation_id, user_message, image=None):
    """Generate ChatBud's reply to one user message, using stored conversation context.

    Steps:
    1. Build the prompt with build_messages().
    2. If it exceeds MAX_CONTEXT_TOKENS - MAX_NEW_TOKENS, drop the oldest
       history.
    3. Sample a reply.
    4. Append the exchange to ``conversations[conversation_id]``. The user's
       text (or "[sent an image]" when there was none) is stored; image data
       never is.

    Args:
        conversation_id: Key of the conversation in ``conversations``.
        user_message: The child's text (may be empty when an image is sent).
        image: Optional PIL image for this message.

    Returns:
        The reply text. If anything fails, the error is printed, no new messages
        are stored, and a friendly retry message is returned instead.
    """
    try:
        with _generation_lock:
            budget = MAX_CONTEXT_TOKENS - MAX_NEW_TOKENS

            messages, _ = build_messages(conversation_id, user_message, image)
            inputs = _encode_messages(messages)
            input_length = inputs["input_ids"].shape[-1]
            print(f"   Context: {input_length} tokens")

            # build_messages() has ensured this entry exists.
            history = conversations[conversation_id]

            if input_length > budget:
                print("   ⚠️ Context too long, trimming history...")
                # Keep the last two exchanges first. Then drop whole
                # user/assistant pairs until the prompt fits or no history is left.
                del history[:-4]
                while True:
                    messages, _ = build_messages(conversation_id, user_message, image)
                    inputs = _encode_messages(messages)
                    input_length = inputs["input_ids"].shape[-1]
                    if input_length <= budget or not history:
                        break
                    del history[:2]

            with torch.inference_mode():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )

            # generate() returns prompt + continuation; decode only the new tokens.
            response = processor.decode(
                output_ids[0][input_length:],
                skip_special_tokens=True,
            ).strip()

            history.append({
                "role": "user",
                "content": [{"type": "text", "text": user_message or "[sent an image]"}],
            })
            history.append({
                "role": "assistant",
                "content": [{"type": "text", "text": response}],
            })

            # Cap stored history at MAX_HISTORY_TURNS exchanges (two messages
            # each) so long chats don't grow memory without bound.
            keep = 2 * MAX_HISTORY_TURNS
            if len(history) > keep:
                del history[:len(history) - keep]

        return response

    except Exception as e:
        print(f"Generation error: {e}")
        return "Oops! Something went wrong. Can you try again?"


@app.route('/api/chat', methods=['POST'])
def chat():
    """Handle one chat turn from the ChatBud UI.

    Expects a JSON object body with these optional keys:
    - ``message``: text; non-string values are treated as absent.
    - ``image``: base64 or data-URL image.
    - ``conversation_id``: defaults to "default".

    Returns:
        - 200 ``{"success": true, "response": ...}`` with the reply. Unexpected
          server errors also return this, with a friendly retry text.
        - 400 ``{"success": false, "error": ...}`` when neither a message nor a
          decodable image was provided.
    """
    try:
        # silent=True: a missing or malformed JSON body becomes {} instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        raw_message = data.get('message')
        message = raw_message.strip() if isinstance(raw_message, str) else ''
        image_data = data.get('image')
        conversation_id = data.get('conversation_id', 'default')

        image = None
        if image_data:
            print("📷 Image received")
            image = decode_base64_image(image_data)

        if not message and image is None:
            return jsonify({'success': False, 'error': 'No message or image'}), 400

        if image is not None:
            print(f"📩 [IMAGE] + '{message[:30]}...'" if message else "📩 [IMAGE]")
        else:
            print(f"📩 {message[:50]}")

        response = generate_response(conversation_id, message, image)
        print(f"📤 {response[:50]}...")

        return jsonify({'success': True, 'response': response})

    except Exception as e:
        print(f"❌ Error: {e}")
        # NOTE(review): success=True so the UI presumably shows this text as a
        # normal reply — confirm before changing it to an error status.
        return jsonify({'success': True, 'response': "Something went wrong. Try again?"})


@app.route('/api/clear', methods=['POST'])
def clear_history():
    """Forget the stored history of one conversation.

    Expects a JSON object body with an optional ``conversation_id`` (default
    "default"). Unknown ids are a no-op and still report success.
    A missing or non-JSON body is rejected with 400 rather than silently
    clearing the default conversation.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'success': False, 'error': 'Expected a JSON object body'}), 400

        conversation_id = data.get('conversation_id', 'default')
        if conversation_id in conversations:
            conversations[conversation_id] = []
        print(f"🗑️ Cleared history for {conversation_id}")
        return jsonify({'success': True, 'message': 'History cleared'})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)})


@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe for the UI: reports status, model name, features and context size."""
    return jsonify(dict(
        status='ok',
        model='Gemma 3 4B-IT + LoRA',
        features=['text', 'images', 'memory'],
        max_context=MAX_CONTEXT_TOKENS,
    ))


def run_flask():
    """Serve the Flask app on all interfaces at PORT; blocks, so it is started in a thread."""
    server_options = dict(host='0.0.0.0', port=PORT, use_reloader=False, threaded=True)
    app.run(**server_options)


# =============================================================================
# START
# =============================================================================

print("🚀 Starting ChatBud...")
print(f"   Context window: {MAX_CONTEXT_TOKENS} tokens")
print(f"   Max history: {MAX_HISTORY_TURNS} turns\n")

# Serve Flask from a daemon thread so this cell can go on to open the tunnel.
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()
time.sleep(2)  # give Flask a moment to bind PORT before tunnelling to it

print("🌐 Starting tunnel...\n")

# Quick tunnel: cloudflared logs a random public *.trycloudflare.com URL.
process = subprocess.Popen(
    ['./cloudflared', 'tunnel', '--url', f'http://localhost:{PORT}'],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
)

# https://api.trycloudflare.com is cloudflared's quick-tunnel API endpoint.
# It can appear in request-error log lines and is not the tunnel URL.
TUNNEL_URL_RE = re.compile(r'https://[a-zA-Z0-9-]+\.trycloudflare\.com')
public_url = None
for _ in range(60):  # scan at most 60 log lines for the URL
    line = process.stdout.readline()
    if not line and process.poll() is not None:
        break  # cloudflared exited; no URL is coming
    match = TUNNEL_URL_RE.search(line)
    if match and match.group(0) != 'https://api.trycloudflare.com':
        public_url = match.group(0)
        break


def _drain_output(stream):
    """Discard cloudflared's remaining output so its pipe never fills up."""
    for _ in stream:
        pass


if public_url:
    # Nothing reads the log after this point. Without draining, the OS pipe
    # buffer eventually fills and cloudflared blocks on its next write.
    threading.Thread(target=_drain_output, args=(process.stdout,), daemon=True).start()

    print("\n" + "=" * 70)
    print("🎉 CHATBUD READY!")
    print("=" * 70)
    print(f"\n📡 URL: {public_url}\n")
    print("✅ Features:")
    print("   • Child-safe system prompt (hardcoded)")
    print("   • Conversation memory (remembers context)")
    print("   • Image understanding")
    print(f"   • {MAX_CONTEXT_TOKENS} token context window")
    print("\n📋 Paste URL into ChatBud UI settings (⚙️)")
    print("⚠️  Keep this cell running!")
    print("=" * 70)

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\n👋 Bye!")
    finally:
        process.terminate()  # don't leave an orphaned tunnel behind
else:
    process.terminate()
    print("❌ Tunnel failed. Run cell again.")