
# Lab 1B: MCP Sampling (Server requests LLM via Client)

**Goal:** Demonstrate MCP *Sampling* by having an MCP server tool call `ctx.session.create_message(...)` to ask the *client* for an LLM completion.  
In this lab, we **don't use external APIs**. Instead, the client implements a simple `sampling_handler` that returns a mock keyword list.

**How it works:**
1. **Cell 1** starts a Streamable HTTP FastMCP server with a tool `extract_keywords(text, k)`.
2. The tool uses `ctx.session.create_message(...)` to request a completion from the *client*.
3. **Cell 2** runs an MCP **client** (`ClientSession`) that supplies a `sampling_handler` (our mock LLM) as its `sampling_callback`, connects to `http://127.0.0.1:8000/mcp`, calls the tool, and prints the result.

> If you restarted the kernel, re-run Cell 1 before Cell 2.


In [1]:
%pip install -q mcp fastmcp 

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Jupyter cell — SERVER (runs in background)
# If it’s a fresh kernel, run once: %pip install -q mcp fastmcp

import asyncio, contextlib
from mcp.server.fastmcp import Context, FastMCP
from mcp.types import SamplingMessage, TextContent

# cancel any prior server task if you re-run the cell
with contextlib.suppress(NameError, asyncio.CancelledError):
    server_task.cancel()
    await asyncio.sleep(0)

mcp = FastMCP(name="Sampling Keywords")

@mcp.tool()
async def extract_keywords(text: str, k: int = 5, ctx: Context = None) -> str:
    """
    Ask the client (via Sampling) to extract top-k keywords from `text`.
    Expected client response: comma-separated, lowercase keywords (no duplicates).
    """
    # The None default on `ctx` is only a placeholder: without a real context
    # there is no session to send the sampling request through.
    assert ctx is not None, "Sampling context missing"

    # The request is assembled from three pieces: the task line (carrying k),
    # the formatting rules, and the raw input under a "TEXT:" marker.
    task_line = f"Extract the top {int(k)} keywords from the following text.\n"
    rules_line = "Rules: lowercase, comma-separated, no duplicates, no extra words.\n\n"
    payload = f"TEXT:\n{text}"
    user_message = SamplingMessage(
        role="user",
        content=TextContent(type="text", text=task_line + rules_line + payload),
    )

    # Sampling: the completion is requested through the session rather than
    # being produced inside this tool.
    reply = await ctx.session.create_message(messages=[user_message], max_tokens=100)

    # Text replies are returned verbatim; anything else is stringified.
    body = reply.content
    if getattr(body, "type", "") != "text":
        return str(body)
    return body.text


async def _run():
    """Run the FastMCP server over Streamable HTTP; returns when the server coroutine ends."""
    await mcp.run_streamable_http_async()

# Schedule the server as a background task so the cell returns immediately.
server_task = asyncio.create_task(_run())
# Report the address the server actually listens on (see the Uvicorn log line
# in the cell output); the client cell connects to this same URL.
print("Server running at http://127.0.0.1:8000/mcp")


INFO:     Started server process [31321]
INFO:     Waiting for application startup.


Server running at http://localhost:5000


INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     127.0.0.1:62516 - "GET / HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:62519 - "POST /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62519 - "POST /mcp/ HTTP/1.1" 200 OK
INFO:     127.0.0.1:62520 - "POST /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62521 - "GET /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62520 - "POST /mcp/ HTTP/1.1" 202 Accepted
INFO:     127.0.0.1:62521 - "GET /mcp/ HTTP/1.1" 200 OK
INFO:     127.0.0.1:62522 - "POST /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62522 - "POST /mcp/ HTTP/1.1" 200 OK
INFO:     127.0.0.1:62523 - "POST /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62523 - "POST /mcp/ HTTP/1.1" 200 OK
INFO:     127.0.0.1:62524 - "POST /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62524 - "POST /mcp/ HTTP/1.1" 202 Accepted
INFO:     127.0.0.1:62525 - "DELETE /mcp HTTP/1.1" 307 Temporary Redirect
INFO:     127.0.0.1:62525 - "DELETE /mcp/ HTTP/1.1" 200 OK


In [3]:
import asyncio
from collections import Counter
import re
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent
from mcp.shared.context import RequestContext

# Common English function words that are never reported as keywords.
STOPWORDS = {
    "the","a","an","and","or","but","if","then","else","for","to","of","in","on","at","by",
    "is","are","was","were","be","being","been","with","as","that","this","these","those",
    "it","its","from","into","out","over","under","about","after","before","than","so","such",
    "we","you","they","he","she","i","me","my","our","your","their","them","his","her",
}

def simple_keywords(text: str, k: int) -> str:
    """Return up to `k` keywords from `text` as a ", "-separated string.

    Words are lowercase ASCII alphanumeric runs; stopwords and single-character
    tokens are dropped. Keywords are ranked by descending frequency, with ties
    broken alphabetically so the result is deterministic.

    Args:
        text: Input text to analyse.
        k: Maximum number of keywords to return. Zero or a negative value
            yields an empty string.

    Returns:
        The selected keywords joined by ", " (empty string when none qualify).
    """
    # A negative k would otherwise slice from the end (`ranked[:-1]` keeps all
    # but the last keyword), which is never what a "top-k" caller wants.
    if k <= 0:
        return ""

    counts = Counter(
        word
        for word in re.findall(r"[a-zA-Z0-9]+", text.lower())
        if len(word) > 1 and word not in STOPWORDS
    )
    # Deterministic tie-breaker: sort by (-freq, word).
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return ", ".join(word for word, _ in ranked[:k])

async def sampling_handler(context: RequestContext, params: types.CreateMessageRequestParams) -> types.CreateMessageResult:
    """Mock LLM: answer a sampling request with locally computed keywords.

    Reads the prompt from the first message, recovers the requested keyword
    count and the text that follows the "TEXT:" marker, and returns the output
    of `simple_keywords` as an assistant text message.
    """
    # Only the first message is inspected; it is expected to hold text content.
    prompt = params.messages[0].content.text

    # Requested keyword count (e.g. "top 5 keywords"); 5 when not stated.
    count_match = re.search(r"top\s+(\d+)\s+keywords", prompt, flags=re.I)
    if count_match:
        wanted = int(count_match.group(1))
    else:
        wanted = 5

    # Everything after the "TEXT:" marker; the whole prompt if it is missing.
    body_match = re.search(r"TEXT:\s*(.*)\Z", prompt, flags=re.I | re.S)
    if body_match:
        source_text = body_match.group(1).strip()
    else:
        source_text = prompt

    reply = TextContent(type="text", text=simple_keywords(source_text, wanted))
    return types.CreateMessageResult(
        role="assistant",
        content=reply,
        model="mock-keywords-1.0",
    )

async def main():
    server_url = "http://127.0.0.1:8000/mcp"  # match your server
    sample_text = "Observability with Prometheus, Loki, and Tempo helps engineers diagnose issues in microservices quickly."

    async with streamablehttp_client(server_url) as (read, write, *_):
        async with ClientSession(read, write, sampling_callback=sampling_handler) as session:
            await session.initialize()

            tools = await session.list_tools()
            print("Tools:", [t.name for t in tools.tools])

            result = await session.call_tool("extract_keywords", {"text": sample_text, "k": 5})
            print("Keywords:", result.content[0].text)

await main()


Tools: ['extract_keywords']


Keywords: diagnose, engineers, helps, issues, loki
