# Early Autism Screening Guidance Chatbot
## Fine-tuning TinyLlama 1.1B with LoRA

**Project:** Academic - AI for Healthcare  
**Target Users:** Caregivers, teachers, and community health workers  
**Purpose:** Raise awareness, provide guidance, and encourage professional screening — **NOT for diagnosis**

---
### What this notebook does:
1. Installs dependencies (transformers, peft, bitsandbytes, trl, etc.)
2. Loads TinyLlama 1.1B with 4-bit quantization
3. Loads and formats the instruction dataset (JSONL)
4. Applies LoRA fine-tuning (memory-efficient)
5. Trains with GPU-friendly settings
6. Saves the fine-tuned LoRA adapter and tokenizer
7. Provides inference + base vs fine-tuned comparison
8. Evaluation: BLEU, ROUGE-L, perplexity, qualitative examples

## 1. Install Dependencies

Run this cell first. On Colab: **Runtime → Change runtime type → GPU (T4 recommended)**

In [None]:
# Install required packages (optimized for Colab/Kaggle free GPU)
!pip install -q transformers datasets peft accelerate bitsandbytes trl scipy nltk evaluate sentencepiece

# For evaluation metrics
import nltk
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

## 2. Imports and Configuration

In [None]:
import json
import torch
from pathlib import Path
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
from trl import SFTTrainer

# Config
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DATASET_PATH = "data/autism_screening_guidance.jsonl"  # Or upload your JSONL
OUTPUT_DIR = "autism_guidance_model"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 3. Load TinyLlama with 4-bit Quantization

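4-bit loading sharply reduces the GPU memory taken by the weights. At 1.1B parameters, TinyLlama needs roughly 4.4 GB for its weights in fp32 and 2.2 GB in fp16, but only about 0.55 GB in 4-bit NF4. Observed usage is higher, because some layers (such as the token embeddings) are not quantized and CUDA itself reserves memory.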
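A quick way to sanity-check memory figures like these is to multiply the parameter count by the bits stored per parameter. The sketch below is back-of-envelope arithmetic only: it assumes exactly 1.1e9 parameters and ignores activations, the KV cache, the CUDA context, and the per-block scale constants that 4-bit quantization adds.

```python
# Back-of-envelope weight memory for a ~1.1B-parameter model.
# Counts weights only; activations, KV cache, CUDA context, and
# quantization constants (per-block absmax scales) are ignored.
N_PARAMS = 1.1e9  # assumed approximate TinyLlama parameter count

BITS_PER_PARAM = {
    "fp32": 32,
    "fp16": 16,
    "nf4": 4,  # 4-bit NF4, before per-block scale overhead
}

def weight_gb(n_params: float, bits: int) -> float:
    """Return weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9

for name, bits in BITS_PER_PARAM.items():
    print(f"{name:>5}: {weight_gb(N_PARAMS, bits):.2f} GB")
```

Comparing these numbers with `torch.cuda.memory_allocated()` after loading shows how much of the real footprint is overhead rather than weights.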

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

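tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Llama-family tokenizers may not define a pad token; reuse EOS so batching works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token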

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model)

print("Model loaded with 4-bit quantization. Ready for LoRA.")

## 4. Load and Format Instruction Dataset

In [None]:
def load_jsonl_dataset(path: str) -> Dataset:
    """Load a JSONL instruction dataset, skipping blank lines."""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # an empty or trailing blank line would make json.loads fail
                data.append(json.loads(line))
    return Dataset.from_list(data)


def format_prompt(example) -> dict:
    """Format instruction + input + output into one chat-style training string.

    The prompt part must match the one built in generate_response, and the
    trailing EOS token teaches the model where an answer ends.
    """
    instruction = example.get("instruction", "")
    inp = example.get("input", "")
    output = example.get("output", "")
    user = f"{instruction}\n{inp}" if inp else instruction
    text = f"<|user|>\n{user}\n<|assistant|>\n{output}{tokenizer.eos_token}"
    return {"text": text}


# Load dataset. For Colab: upload the data/ folder or run create_dataset.py first.
if not Path(DATASET_PATH).exists():
    Path("data").mkdir(exist_ok=True)
    try:
        from create_dataset import create_dataset
        create_dataset(DATASET_PATH)
        print("Generated full dataset from create_dataset.py")
    except Exception as exc:
        print(f"create_dataset.py unavailable ({exc}); writing a minimal fallback dataset.")
        DISCLAIMER = " This is not a diagnosis. Please consult a healthcare professional."
        mini = [
            {"instruction": "What are early signs of autism?", "input": "", "output": "Early signs may include limited eye contact, delayed speech, repetitive behaviors, and reduced social interaction. Consider screening with a healthcare provider." + DISCLAIMER},
            {"instruction": "When should I seek screening?", "input": "", "output": "Seek screening if you notice delayed speech, few gestures, or social differences. Routine autism screening at 18 and 24 months is recommended." + DISCLAIMER},
        ] * 50
        with open(DATASET_PATH, "w", encoding="utf-8") as f:
            for ex in mini:
                f.write(json.dumps(ex, ensure_ascii=False) + "\n")
        print(f"Created minimal fallback dataset: {len(mini)} examples")

raw_dataset = load_jsonl_dataset(DATASET_PATH)
formatted = raw_dataset.map(format_prompt, remove_columns=raw_dataset.column_names)

print(f"Loaded {len(formatted)} examples")
print("Sample formatted text:")
print(formatted[0]["text"][:500] + "...")
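Because `generate_response` builds its prompt by hand rather than through a chat template, the inference prompt has to be an exact prefix of every training string, and each training target should end with an end-of-sequence marker so the model learns where to stop. The dependency-free sketch below mirrors the layout used in this notebook and checks that invariant; `build_prompt`, `build_training_text`, and the literal `EOS = "</s>"` are hypothetical stand-ins (in the notebook itself, use `tokenizer.eos_token`).

```python
# Minimal sketch of the prompt layout used in this notebook.
# EOS = "</s>" is an assumption (the usual Llama-family end-of-sequence string);
# in the notebook, use tokenizer.eos_token instead.
EOS = "</s>"

def build_prompt(instruction: str, inp: str = "") -> str:
    """Inference-time prompt: the user turn plus an open assistant turn."""
    user = f"{instruction}\n{inp}" if inp else instruction
    return f"<|user|>\n{user}\n<|assistant|>\n"

def build_training_text(instruction: str, inp: str, output: str) -> str:
    """Training text: the same prompt, then the target answer and EOS."""
    return build_prompt(instruction, inp) + output + EOS

prompt = build_prompt("When should I seek screening?")
train_text = build_training_text("When should I seek screening?", "", "Talk to a pediatrician.")

# At inference the model continues from `prompt`, so every training example
# must begin with exactly that string.
assert train_text.startswith(prompt)
print(repr(prompt))
```

If the training and inference layouts drift apart (an extra newline, a different role marker), the fine-tuned model sees a prompt it was never trained on, which can quietly degrade the comparison in Section 8.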

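## 5. Apply LoRA

This section has two steps: first capture base-model outputs (5a), then wrap the model with LoRA adapters.

## 5a. Base Model Outputs (Run Before LoRA)

Capture base-model responses for later comparison. Run this cell before Section 6: after training, `model` carries the LoRA updates, so re-running it would capture fine-tuned outputs instead.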

In [None]:
def generate_response(model, tokenizer, question, max_new_tokens=256, temperature=0.7, do_sample=True):
    prompt = f"<|user|>\n{question}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

TEST_PROMPTS = [
    "What are early signs of autism in toddlers?",
    "When should I seek autism screening?",
    "Is autism caused by vaccines?",
]
BASE_OUTPUTS = [generate_response(model, tokenizer, q) for q in TEST_PROMPTS]
print("Base outputs saved. Proceed to LoRA.")

In [None]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
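For each frozen weight matrix W of shape (d_out, d_in), LoRA trains two small matrices, A of shape (r, d_in) and B of shape (d_out, r), so each targeted projection adds r × (d_in + d_out) parameters, and PEFT scales the update by `lora_alpha / r`. The sketch below does that count by hand. The layer shapes are assumptions based on TinyLlama's published config (hidden size 2048, 22 layers, 32 query heads, 4 key/value heads, head dim 64); confirm them against `model.config` before trusting the total.

```python
# Count the trainable parameters LoRA adds for r=8 on q/k/v/o projections.
# For W with shape (d_out, d_in), LoRA trains A: (r, d_in) and B: (d_out, r),
# i.e. r * (d_in + d_out) extra parameters per targeted matrix.
# Shapes are assumptions based on TinyLlama's published config.
HIDDEN = 2048
N_LAYERS = 22
KV_DIM = 4 * 64  # num_key_value_heads * head_dim (grouped-query attention)

# (d_in, d_out) for each targeted projection in one decoder layer
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}

def lora_params(r: int) -> int:
    """Total LoRA parameters across all layers for rank r."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in TARGETS.values())
    return per_layer * N_LAYERS

r, alpha = 8, 16
print(f"trainable LoRA params: {lora_params(r):,}")
print(f"update scaling alpha / r = {alpha / r}")
```

If the assumed shapes hold, the total should land close to what `model.print_trainable_parameters()` reports. Doubling `r` doubles the adapter size, while keeping `lora_alpha / r` fixed keeps the update scale comparable across runs.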

## 6. Training

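Settings chosen to fit a free Colab T4 or Kaggle P100 (16 GB each). The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps` = 2 × 8 = 16.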

In [None]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=20,
    save_strategy="epoch",
    report_to="none",
    remove_unused_columns=False,
)

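# NOTE: dataset_text_field, max_seq_length and packing are passed directly to
# SFTTrainer here, which matches older TRL releases. Newer releases move these
# options into trl.SFTConfig; if this call raises a TypeError, pin an older trl
# or move the options into an SFTConfig.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted,
    dataset_text_field="text",
    max_seq_length=512,
    packing=False,
)

print("Starting training...")
trainer.train()
# For a PEFT model this saves the LoRA adapter weights, not a merged full model.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Training complete. LoRA adapter and tokenizer saved to {OUTPUT_DIR}/")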
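It helps to know roughly how many optimizer steps a run will take before reading the loss log. The sketch below approximates that from the `TrainingArguments` used above (batch 2, accumulation 8, 2 epochs, `logging_steps=20`); the exact count the Trainer reports may differ by one depending on how it handles a final partial accumulation window. With the 100-example fallback dataset the whole run is only about 14 steps, which is fewer than `logging_steps`, so interim loss lines may not appear at all; lowering `logging_steps` fixes that.

```python
import math

# Approximate optimizer-step arithmetic for the TrainingArguments above.
# The real Trainer count may differ by one step per epoch.
PER_DEVICE_BATCH = 2
GRAD_ACCUM = 8
EPOCHS = 2
LOGGING_STEPS = 20

def effective_batch(per_device: int, accum: int, n_devices: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * accum * n_devices

def approx_total_steps(n_examples: int) -> int:
    """Approximate optimizer steps over all epochs."""
    per_epoch = math.ceil(n_examples / effective_batch(PER_DEVICE_BATCH, GRAD_ACCUM))
    return per_epoch * EPOCHS

for n in (100, 1000):
    steps = approx_total_steps(n)
    logs = steps // LOGGING_STEPS
    print(f"{n} examples -> ~{steps} optimizer steps, ~{logs} interim loss log lines")
```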

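## 7. Inference Function

This cell re-defines `generate_response` (first defined in Section 5a) so the remaining sections also run if 5a was skipped.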

In [None]:
def generate_response(
    model,
    tokenizer,
    question: str,
    max_new_tokens=256,
    temperature=0.7,
    do_sample=True,
) -> str:
    prompt = f"<|user|>\n{question}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

## 8. Base vs Fine-tuned Comparison

In [None]:
# Compare with BASE_OUTPUTS from Section 5a (run that before LoRA)
try:
    _ = TEST_PROMPTS
except NameError:
    TEST_PROMPTS = ["What are early signs of autism in toddlers?", "When should I seek autism screening?", "Is autism caused by vaccines?"]
    BASE_OUTPUTS = ["(Run Section 5a before LoRA to capture base outputs)"] * 3

print("=" * 60)
print("BASE vs FINE-TUNED MODEL")
print("=" * 60)

for i, q in enumerate(TEST_PROMPTS):
    base_out = BASE_OUTPUTS[i] if i < len(BASE_OUTPUTS) else "N/A"
    ft_out = generate_response(model, tokenizer, q)
    print(f"\nQuestion: {q}\n")
    print(f"Base:      {(base_out[:400] + '...') if len(str(base_out)) > 400 else base_out}")
    print(f"\nFine-tuned: {(ft_out[:400] + '...') if len(ft_out) > 400 else ft_out}")
    print("-" * 60)

## 9. Evaluation Metrics

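Compute BLEU, ROUGE-L, and perplexity on a subset of the data. Note that the split below is drawn after training on the full dataset, so these examples were seen during training and the scores are likely optimistic. For a true held-out estimate, split the data before Section 6 and train only on the training portion.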
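A held-out score only means something if no evaluation prompt also appears in training. The stdlib sketch below splits records before training and lists test records whose prompt also occurs in the training split; `split_records`, `prompt_key`, and `leaked_prompts` are hypothetical helpers, and records are assumed to be dicts with `instruction` and `input` keys as in the JSONL above. On the fallback dataset (two questions repeated 50 times) every test prompt leaks, so deduplicating before splitting is the first fix.

```python
import random

# Hypothetical helpers: split records *before* training and report test
# records whose (instruction, input) pair also appears in the training split.
def split_records(records, test_fraction=0.1, seed=42):
    rng = random.Random(seed)
    shuffled = records[:]
    rng.shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]  # (train, test)

def prompt_key(rec):
    return (rec.get("instruction", ""), rec.get("input", ""))

def leaked_prompts(train, test):
    train_keys = {prompt_key(r) for r in train}
    return [r for r in test if prompt_key(r) in train_keys]

# Mirrors the fallback dataset: 2 distinct questions, each repeated 50 times.
mini = [{"instruction": "What are early signs of autism?", "input": ""},
        {"instruction": "When should I seek screening?", "input": ""}] * 50
train, test = split_records(mini)
print(f"{len(leaked_prompts(train, test))} of {len(test)} test prompts also appear in train")
```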

In [None]:
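from nltk.translate.bleu_score import corpus_bleu
import evaluate

rouge = evaluate.load("rouge")

def compute_metrics(predictions, references):
    # corpus_bleu takes a list of tokenized reference lists and a list of tokenized hypotheses.
    bleu = corpus_bleu([[r.split()] for r in references], [p.split() for p in predictions])
    rouge_result = rouge.compute(predictions=predictions, references=references)
    return {"bleu": bleu, "rouge_l": rouge_result["rougeL"]}

# Sample eval set from dataset
eval_data = raw_dataset.train_test_split(test_size=0.1, seed=42)
# Slicing a Dataset (ds[:50]) returns a dict of columns; select() keeps row-wise access.
eval_split = eval_data["test"]
eval_examples = eval_split.select(range(min(50, len(eval_split))))

def to_question(ex):
    # Match the training prompt: the instruction, plus the input when present.
    return f"{ex['instruction']}\n{ex['input']}" if ex.get("input") else ex["instruction"]

references = [ex["output"] for ex in eval_examples]
predictions = [generate_response(model, tokenizer, to_question(ex)) for ex in eval_examples]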

metrics = compute_metrics(predictions, references)
print(f"BLEU: {metrics['bleu']:.4f}")
print(f"ROUGE-L: {metrics['rouge_l']:.4f}")
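ROUGE-L scores a prediction by the longest common subsequence (LCS) of tokens it shares with the reference, combined into an F-measure of LCS-precision and LCS-recall. The dependency-free sketch below shows the idea; lowercased whitespace tokenization is an assumption, and the `evaluate` implementation tokenizes differently (and can apply stemming), so its scores will not match exactly.

```python
# Dependency-free sketch of sentence-level ROUGE-L (LCS-based F1).
def lcs_length(a, b):
    """Length of the longest common subsequence, computed one DP row at a time."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]

def rouge_l_f1(prediction: str, reference: str) -> float:
    p, r = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(p, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(p), lcs / len(r)
    return 2 * precision * recall / (precision + recall)

# A short but accurate answer has perfect precision and lower recall.
print(rouge_l_f1("the cat sat", "the cat sat on the mat"))
```

Because LCS rewards in-order overlap rather than exact n-grams, ROUGE-L is more forgiving than BLEU of paraphrased guidance answers, though neither checks whether the medical content is correct.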

In [None]:
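# Perplexity on a small sample of texts (one forward pass per text)
def compute_perplexity(model, tokenizer, texts, max_length=256, max_samples=20):
    """Return exp of the mean per-example loss over the first max_samples texts."""
    total_loss = 0.0
    count = 0
    model.eval()
    for t in texts[:max_samples]: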
        inputs = tokenizer(t, return_tensors="pt", truncation=True, max_length=max_length).to(model.device)
        with torch.no_grad():
            out = model(**inputs, labels=inputs["input_ids"])
        total_loss += out.loss.item()
        count += 1
    avg_loss = total_loss / count if count else 0
    return torch.exp(torch.tensor(avg_loss)).item()

# NOTE: these texts come from the training data, so this measures fit, not generalization.
eval_texts = [ex["text"] for ex in formatted.select(range(min(50, len(formatted))))]
ppl = compute_perplexity(model, tokenizer, eval_texts)
print(f"Perplexity: {ppl:.2f}")
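Perplexity is exp of the mean negative log-likelihood per token. The cell above averages each example's mean loss, which gives short and long texts equal weight; weighting each loss by the number of tokens it was computed over gives the corpus-level value instead. The sketch below contrasts the two using made-up illustration values; in practice the token count per example is the sequence length minus one, because a causal LM with `labels=input_ids` scores each token against the next.

```python
import math

# Two ways to aggregate per-example losses into perplexity.
# The losses and token counts below are made-up illustration values.
def ppl_example_mean(losses):
    """exp of the unweighted mean of per-example mean losses (as in the cell above)."""
    return math.exp(sum(losses) / len(losses))

def ppl_token_weighted(losses, n_tokens):
    """exp of total NLL divided by total scored tokens (corpus-level perplexity)."""
    total_nll = sum(loss * n for loss, n in zip(losses, n_tokens))
    return math.exp(total_nll / sum(n_tokens))

losses = [2.0, 1.0]   # per-example mean NLL, e.g. out.loss.item()
n_tokens = [10, 30]   # tokens scored per example
print(f"example-mean PPL:   {ppl_example_mean(losses):.3f}")
print(f"token-weighted PPL: {ppl_token_weighted(losses, n_tokens):.3f}")
```

The two agree when all texts have similar length; with the fallback dataset's near-identical examples the difference is small, but it grows with varied answer lengths.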

## 10. Experiments Table (Fill in after runs)

| Experiment | LR | Batch | Grad Accum | Epochs | GPU RAM | Loss | BLEU | ROUGE-L | PPL |
|------------|-----|-------|------------|--------|---------|------|------|---------|-----|
| Base       | –   | –     | –          | –      | ~4GB    | –    | –    | –       | –   |
| Exp 1      | 2e-5| 2     | 8          | 2      | ~8GB    | –    | –    | –       | –   |
| Exp 2      | 5e-5| 4     | 4          | 1      | ~9GB    | –    | –    | –       | –   |

Use this table in your report.

## 11. Interactive Chat (Quick Test)

In [None]:
def chat(question: str) -> str:
    return generate_response(model, tokenizer, question)

# Example:
print(chat("What are early signs of autism in toddlers?"))

In [None]:
# Optional: Gradio UI (uncomment to use)
# import gradio as gr
# gr.Interface(fn=chat, inputs="text", outputs="text", title="Autism Screening Guidance").launch(share=True)