
# 11.4.1 Enhancing AI‑Driven Customer Support  
**Chapter 11: Coreference Resolution & Text Entailment**

This notebook demonstrates a practical, end‑to‑end use case for **AI‑driven customer support** that aligns with Chapter 11’s focus on **coreference resolution** and **text entailment**.  
We build a retrieval‑augmented question answering (QA) system over a small customer‑support knowledge base, and show:
- **Coreference‑aware retrieval** (simple, spaCy‑based heuristic) to clarify pronouns in user queries (e.g., “When can I cancel it?” → replace *it* with the last detected entity).
- **Text entailment checking** (NLI scoring) to validate that answers are supported by retrieved context.
- **Hybrid QA** with both **extractive QA from a small language model (SLM)** (DistilBERT) and **generative QA from an LLM** (Gemma via the Hugging Face Inference API, with a BitNet fallback and a tiny local fallback).

> **What you’ll get:** a working, reproducible pipeline that loads sample docs, indexes with embeddings + Chroma, retrieves passages, performs coref‑aware querying, answers with either extractive or generative models, and scores answers with NLI entailment.
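As a rough illustration of the "retrieves passages" step, the sketch below ranks passages by cosine similarity between a query vector and passage vectors. The function names `cosine_similarity` and `top_k_passages` and the three-dimensional toy vectors are assumptions made for this example; the notebook itself delegates embedding and search to `all-MiniLM-L6-v2` and Chroma, whose internal indexing and distance metric may differ.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_passages(
    query_vec: List[float],
    passage_vecs: Dict[str, List[float]],
    k: int = 2,
) -> List[Tuple[str, float]]:
    """Return the k passage ids most similar to the query, best first."""
    scored = [(pid, cosine_similarity(query_vec, vec)) for pid, vec in passage_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy "embeddings"; a real embedding model produces far more dimensions.
toy_index = {
    "billing_plans": [0.9, 0.1, 0.0],
    "api_limits": [0.0, 0.2, 0.9],
    "cancellation_policy": [0.7, 0.6, 0.1],
}
toy_query = [1.0, 0.3, 0.0]

ranked = top_k_passages(toy_query, toy_index, k=2)
for passage_id, similarity in ranked:
    print(f"{passage_id}: {similarity:.3f}")
```

A vector store performs the same ranking at scale, usually with an approximate nearest-neighbor index rather than the exhaustive scan shown here.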


In [None]:

# --- Install dependencies (CPU friendly defaults) ---
import sys, subprocess

def pip_install(pkgs):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet"] + pkgs)

# Core libs
pip_install([
    "langchain==0.2.11",
    "langchain-community==0.2.10",
    "langchain-chroma==0.1.2",
    "langchain-huggingface==0.0.3",
    "chromadb==0.5.5",
    "sentence-transformers==3.0.1",
    "transformers==4.43.4",
    "accelerate==0.34.2",
    "torch",                # unpinned; pip picks the default build for your platform (runs on CPU; on Linux the default wheel typically also bundles CUDA libraries)
    "spacy==3.7.4",
])

# Download a lightweight English pipeline for spaCy (used in coref heuristic)
subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])

print("✅ Dependencies installed.")


In [None]:

# --- Imports & basic setup ---
import os, shutil, random, json, math
from typing import List, Dict, Any, Tuple

import torch
from transformers import (
    pipeline,
    AutoTokenizer, AutoModelForSequenceClassification
)

import spacy

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, HuggingFacePipeline
from langchain.chains import RetrievalQA

random.seed(42)
torch.manual_seed(42)

# Path for vector store persistence
PERSIST_DIR = "./chroma_customer_support"

# Hugging Face token for hosted LLMs (optional but recommended for Gemma/BitNet)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

print("HF_TOKEN found?" , bool(HF_TOKEN))
print("Using persist dir:", PERSIST_DIR)



## Sample Customer Support Knowledge Base
To keep the demo self‑contained, we’ll synthesize a small but realistic knowledge base for a fictional SaaS platform:

- **Accounts & Security:** password reset, multi‑factor auth, account recovery  
- **Billing & Plans:** trial, refunds, invoices, upgrades/downgrades, VAT/GST  
- **Support & SLA:** hours, severity levels, response times  
- **API & Rate Limits:** keys, throttling, status codes  
- **Compliance & Privacy:** data retention, GDPR/CCPA, data residency  
- **Cancellations:** end‑of‑term policy, prorations

These documents are kept short and friendly for quick iteration.


In [None]:

# --- Create a small in-memory corpus of customer-support docs ---
docs = [
    {
        "id": "accounts_security",
        "title": "Accounts & Security",
        "text": (
            "Password resets are available via the 'Forgot Password' link on the login page. "
            "Users receive a one-time reset email that expires in 30 minutes. "
            "We support multi-factor authentication via authenticator apps. "
            "If you cannot access your email, contact support to verify identity and trigger recovery."
        )
    },
    {
        "id": "billing_plans",
        "title": "Billing & Plans",
        "text": (
            "We offer Free, Pro, and Enterprise plans. Trials last 14 days on Pro. "
            "Refunds are available within 7 days of charge for monthly Pro, subject to fair use. "
            "Invoices are emailed automatically and accessible from the Billing portal. "
            "VAT/GST is calculated based on your billing address."
        )
    },
    {
        "id": "support_sla",
        "title": "Support & SLA",
        "text": (
            "Support is available 24/7 via email and chat for Pro and Enterprise. "
            "SLA response times by severity: Sev-1 within 1 hour, Sev-2 within 4 hours, Sev-3 within 1 business day. "
            "Free plan users can access our community forum and knowledge base."
        )
    },
    {
        "id": "api_limits",
        "title": "API & Rate Limits",
        "text": (
            "All API requests require an API key in the Authorization header. "
            "The default rate limit is 100 requests per minute per key. "
            "Bursting may be temporarily allowed; exceeding limits returns HTTP 429 with a Retry-After header."
        )
    },
    {
        "id": "privacy_compliance",
        "title": "Compliance & Privacy",
        "text": (
            "We comply with GDPR and CCPA. "
            "Customer data is retained for 30 days after account cancellation unless retention is required by law. "
            "Enterprise customers can request EU-only data residency."
        )
    },
    {
        "id": "cancellation_policy",
        "title": "Cancellation Policy",
        "text": (
            "You can cancel any time from the Billing portal. "
            "For monthly plans, cancellation takes effect at the end of the current billing period; no partial refunds. "
            "Annual plans are non-refundable after 30 days, except where required by local law."
        )
    }
]

print(f"Loaded {len(docs)} documents.")


In [None]:

# --- Chunk, embed, and index in Chroma ---
# Clean up previous index if it exists
if os.path.exists(PERSIST_DIR):
    shutil.rmtree(PERSIST_DIR)

# Split into chunks
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = []
for d in docs:
    for chunk in splitter.split_text(d["text"]):
        chunks.append({"page_content": chunk, "metadata": {"source": d["id"], "title": d["title"]}})

print("Total chunks:", len(chunks))

# Embeddings (SLM): all-MiniLM-L6-v2 is small & strong for retrieval
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Vector store
from langchain_core.documents import Document

vectordb = Chroma.from_documents(
    documents=[
        Document(page_content=c["page_content"], metadata=c["metadata"])
        for c in chunks
    ],
    embedding=embeddings,
    persist_directory=PERSIST_DIR
)
retriever = vectordb.as_retriever(search_kwargs={"k": 4})
print("✅ Vector store ready.")



## Coreference Resolution (Heuristic) & Text Entailment (NLI)
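- **Coreference:** We implement a lightweight, **spaCy‑based heuristic**: find the most recent named entity and substitute it for ambiguous pronouns in the question. This is much weaker than a neural coreference model (it only considers named entities, so it cannot resolve a pronoun that refers to a common noun such as "my subscription"), but it is simple and fast.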
- **Entailment:** We score whether the retrieved **context entails the answer** using a compact NLI model. If entailment is low, we can flag the answer as potentially unsupported.
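To make the pronoun substitution in the coreference bullet concrete, here is a minimal sketch that uses only the standard library. It assumes the antecedent has already been found (the notebook uses spaCy named entities for that part); the name `replace_pronouns` and the reduced pronoun list are illustrative choices, not part of the notebook.

```python
import re
from typing import Optional

# Whole-word, case-insensitive match on a small set of pronouns.
PRONOUN_PATTERN = re.compile(r"\b(it|they|them|he|she|him)\b", flags=re.IGNORECASE)


def replace_pronouns(question: str, antecedent: Optional[str]) -> str:
    """Substitute matching pronouns with the antecedent, leaving spacing and punctuation untouched."""
    if not antecedent:
        # Nothing to substitute: return the question unchanged.
        return question
    # A function replacement avoids re.sub interpreting backslashes in the antecedent.
    return PRONOUN_PATTERN.sub(lambda _match: antecedent, question)


print(replace_pronouns("When can I cancel it?", "the Pro plan"))
# -> When can I cancel the Pro plan?
```

Because the pattern is anchored on word boundaries, words that merely contain a pronoun (for example "submit") are not rewritten. A pure regex has no notion of part of speech, which is one reason to prefer a tagger-based approach in practice.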


In [None]:

# --- Coreference heuristic with spaCy (pronoun → last entity) ---
nlp = spacy.load("en_core_web_sm")

_PRONOUNS = {"it","they","them","he","she","him","her","its","their","theirs","his","hers","this","that"}

def resolve_coref(question: str, chat_history: str = "") -> str:
    '''Very lightweight heuristic:
    - Look at chat_history + question.
    - Take the last named entity (ORG/PRODUCT/PERSON/GPE).
    - Replace ambiguous pronouns in the question with that entity.
    '''
    doc = nlp((chat_history + " " + question).strip())
    entities = [ent.text for ent in doc.ents if ent.label_ in ("ORG","PRODUCT","PERSON","GPE")]
    target = entities[-1] if entities else None
    if not target:
        return question

    q_doc = nlp(question)
    pieces = []
    for t in q_doc:
        # Only rewrite tokens tagged as pronouns, so a determiner such as "that" in "that plan" is left alone.
        if t.lower_ in _PRONOUNS and t.pos_ == "PRON":
            # Keep the token's trailing whitespace so punctuation stays attached ("it?" -> "Acme?").
            pieces.append(target + t.whitespace_)
        else:
            pieces.append(t.text_with_ws)
    return "".join(pieces)

# --- Entailment (NLI) scorer ---
NLI_MODEL = "cross-encoder/nli-deberta-base"  # compact & accurate
nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL)
nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)

def nli_entailment(premise: str, hypothesis: str) -> Tuple[str, float]:
    '''Returns (label, score) where label in {'ENTAILMENT','NEUTRAL','CONTRADICTION'}'''
    inputs = nli_tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = nli_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    idx = int(torch.argmax(logits, dim=-1))
    label = nli_model.config.id2label[idx]
    score = float(probs[idx])
    return label.upper(), score

print("✅ Coref heuristic & NLI loaded.")



## Extractive QA (SLM: DistilBERT)
We use a classic extractive QA pipeline: **`distilbert-base-uncased-distilled-squad`**.  
Given the retrieved context, it extracts the best answer span.
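The idea of "extracting the best answer span" can be sketched as follows: an extractive model assigns each token a start score and an end score, and the answer is the span whose combined score is highest, subject to the end not preceding the start. This is a simplified illustration with made-up scores and a hypothetical helper `best_span`; the actual `transformers` pipeline applies additional steps (such as handling subword tokens and long contexts) that are omitted here.

```python
from typing import List, Tuple


def best_span(
    start_scores: List[float],
    end_scores: List[float],
    max_answer_len: int = 15,
) -> Tuple[int, int, float]:
    """Return (start, end, score) maximizing start_scores[start] + end_scores[end]
    with start <= end and a span of at most max_answer_len tokens."""
    best = (0, 0, float("-inf"))
    for i, start_score in enumerate(start_scores):
        last = min(i + max_answer_len, len(end_scores))
        for j in range(i, last):
            score = start_score + end_scores[j]
            if score > best[2]:
                best = (i, j, score)
    return best


# Made-up scores for a six-token context.
tokens = ["reset", "email", "expires", "in", "30", "minutes"]
start_scores = [0.1, 0.2, 0.1, 0.3, 2.5, 0.4]
end_scores = [0.0, 0.1, 0.2, 0.1, 0.9, 2.8]

start, end, score = best_span(start_scores, end_scores)
answer = " ".join(tokens[start:end + 1])
print(answer)
```

With these scores the selected span covers the last two tokens, because token 4 has the highest start score and token 5 the highest end score.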


In [None]:

# --- Extractive QA pipeline ---
qa_extractive = pipeline(
    "question-answering",
    model="distilbert-base-uncased-distilled-squad",
    tokenizer="distilbert-base-uncased-distilled-squad",
)

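def ask_question_extractive(question: str, chat_history: str = "", top_k: int = 4):
    q_resolved = resolve_coref(question, chat_history=chat_history)
    # Query the vector store directly so that top_k is actually honored
    docs = vectordb.similarity_search(q_resolved, k=top_k)
    if not docs:
        # Return the same keys as the normal path so callers can rely on one shape
        return {"answer": "", "score": 0.0, "resolved_question": q_resolved,
                "context_preview": "", "entailment": ("NEUTRAL", 0.0), "sources": []}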
    # Use the top document for extractive QA context
    context = docs[0].page_content
    result = qa_extractive(question=q_resolved, context=context)
    label, score = nli_entailment(context, result["answer"])
    return {
        "answer": result["answer"],
        "score": float(result.get("score", 0.0)),
        "resolved_question": q_resolved,
        "context_preview": context[:400] + ("..." if len(context) > 400 else ""),
        "entailment": (label, score),
        "sources": [d.metadata for d in docs]
    }

print("✅ Extractive QA ready.")



## Generative QA (LLM)
We prefer a hosted LLM via **Hugging Face Inference API** for convenience:
- Primary: `google/gemma-7b-it`
- Fallback: `microsoft/bitnet-b1.58-2B-4T` *(if available)*
- Local tiny fallback (no token required): `distilgpt2` wrapped with `HuggingFacePipeline`

> Set your token in the environment as `HF_TOKEN` or `HUGGINGFACEHUB_API_TOKEN` to use hosted models.


In [None]:

# --- Initialize LLM(s) ---
llm = None
llm_name = None

def try_init_endpoint(repo_id: str):
    global llm, llm_name
    try:
        _llm = HuggingFaceEndpoint(
            repo_id=repo_id,
            huggingfacehub_api_token=HF_TOKEN,
            task="text-generation",
            temperature=0.2,
            max_new_tokens=256,
            repetition_penalty=1.05,
        )
        # A small probe call to validate connectivity
        _ = _llm.invoke("Say 'ready'.")
        llm = _llm
        llm_name = repo_id
        print(f"✅ Using hosted LLM: {repo_id}")
        return True
    except Exception as e:
        print(f"Could not init {repo_id} via endpoint -> {e}")
        return False

if HF_TOKEN:
    if not try_init_endpoint("google/gemma-7b-it"):
        # BitNet repo name may vary; this is best-effort. If it fails, we fall back locally.
        try_init_endpoint("microsoft/bitnet-b1.58-2B-4T")

if llm is None:
    # Local tiny fallback
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline
    tok = AutoTokenizer.from_pretrained("distilgpt2")
    mdl = AutoModelForCausalLM.from_pretrained("distilgpt2")
    gen_pipe = hf_pipeline("text-generation", model=mdl, tokenizer=tok, max_new_tokens=256)
    llm = HuggingFacePipeline(pipeline=gen_pipe)
    llm_name = "distilgpt2 (local fallback)"
    print("✅ Using local tiny fallback LLM: distilgpt2")

print("LLM initialized as:", llm_name)


In [None]:

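# --- Build a standard RetrievalQA chain (stuff prompt) ---
from langchain_core.prompts import PromptTemplate

SYSTEM_PROMPT = (
    "You are a helpful, precise customer-support assistant. "
    "Answer ONLY from the provided context. If the answer is not present, say you don't know."
)

# Put the instructions in the prompt template so that only the question itself
# is sent to the retriever (instructions in the query would skew the embedding search).
QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=SYSTEM_PROMPT + "\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:",
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": QA_PROMPT},
)

def format_context(docs) -> str:
    blocks = []
    for i, d in enumerate(docs, 1):
        meta = d.metadata
        blocks.append(f"[{i}] {meta.get('title','')} ({meta.get('source','')})\n{d.page_content}")
    return "\n\n".join(blocks)

def ask_question_generative(question: str, chat_history: str = ""):
    q_resolved = resolve_coref(question, chat_history=chat_history)
    result = qa_chain.invoke({"query": q_resolved})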
    answer = result["result"]
    sources = result.get("source_documents", []) or []
    context_text = format_context(sources) if sources else ""
    # Compute entailment between concatenated context and answer
    label, score = nli_entailment(context_text, answer) if context_text else ("NEUTRAL", 0.0)
    return {
        "answer": answer.strip(),
        "resolved_question": q_resolved,
        "entailment": (label, score),
        "sources": [s.metadata for s in sources],
        "context_preview": (context_text[:600] + ("..." if len(context_text) > 600 else "")) if context_text else ""
    }

print("✅ Generative QA chain ready.")



## Try it out
Ask a few common customer‑support questions. We’ll show both **extractive** and **generative** answers plus **entailment** scores.


In [None]:

test_questions = [
    "How do I reset my password?",
    "What are the SLA response times for Sev-1 and Sev-2?",
    "Can I get a refund on Pro?",
    "What is the API rate limit and what happens if I exceed it?",
    "Do you support EU-only data residency?",
    "If I cancel it now, when does it take effect?"
]

for q in test_questions:
    print("="*90)
    print("Q:", q)
    print("--- Extractive QA ---")
    ex = ask_question_extractive(q)
    print(json.dumps(ex, indent=2))

    print("\n--- Generative QA ---")
    ge = ask_question_generative(q)
    print(json.dumps(ge, indent=2))

print("\nTip: Set HF_TOKEN for better LLM results (Gemma via Inference API).")


In [None]:

# --- Ask your own question here ---
user_q = "When are refunds available for monthly Pro?"
print("Q:", user_q)
print("\nExtractive:", json.dumps(ask_question_extractive(user_q), indent=2))
print("\nGenerative:", json.dumps(ask_question_generative(user_q), indent=2))



## Deployment Options & Scalability
- **Local / Internal:** Run this pipeline inside your private network. Persist the Chroma index and secure the LLM endpoints with org‑scoped tokens.
- **Cloud:** Use managed vector DBs (e.g., a server‑hosted Chroma/PGVector) and LLM endpoints on Hugging Face or other providers. Scale the retriever API horizontally.
- **Containerization (Docker):** Package the notebook code into a lightweight API service. Add observability for retrieval quality (hit rate, latency) and answer safety (entailment score thresholds).
- **Chapter alignment:** The pipeline demonstrates **coreference‑aware retrieval** and **NLI‑based validation**, two core capabilities from Chapter 11.
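One way to act on the entailment score thresholds mentioned above is to gate each answer before it reaches the customer. The sketch below assumes the `(label, score)` tuple shape returned by the notebook's NLI scorer; the function name `gate_answer`, the 0.7 threshold, and the fallback message are illustrative assumptions that would need tuning on real traffic.

```python
from typing import Dict, Tuple

# Hypothetical message shown when the answer is not sufficiently supported.
FALLBACK_MESSAGE = "I'm not sure about that. Let me connect you with a support agent."


def gate_answer(
    answer: str,
    entailment: Tuple[str, float],
    threshold: float = 0.7,
) -> Dict[str, object]:
    """Pass the answer through only if the context entails it with enough confidence."""
    label, score = entailment
    supported = label.upper() == "ENTAILMENT" and score >= threshold
    return {
        "answer": answer if supported else FALLBACK_MESSAGE,
        "supported": supported,
        "entailment": entailment,
    }


print(gate_answer("within 7 days of charge", ("ENTAILMENT", 0.93)))
print(gate_answer("within 90 days", ("CONTRADICTION", 0.88)))
```

Logging the `supported` flag alongside each response also gives a simple metric to monitor over time.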



### Troubleshooting
- If the **hosted LLM** fails (no token or repo unavailable), the notebook automatically falls back to **`distilgpt2`**.  
- If **PyTorch** installation fails on your platform, install it with your platform‑specific command from the official PyTorch site and rerun the install cell.
- For **better coreference**, replace the heuristic with a neural coref model (e.g., via spaCy/AllenNLP) if your environment permits.
- For **faster indexing**, embed documents in larger batches or run in a GPU‑enabled environment.
