# LLM Model

Using Ollama and Gemma.

In [None]:

# --- Setup ---
import numpy as np
import shap
from ollama import Client

# 1) Open a client for the local Ollama server.
#    With no arguments it targets the default host http://localhost:11434;
#    to use another one, pass it explicitly: Client(host='http://localhost:11434')
client = Client()

# 2) Give every diagnosis class a one-character ID (A, B, C, ...).
#    A single ASCII uppercase letter keeps the expected answer short.
class_names = list(le_y.classes_)  # or your own list of labels
assert len(class_names) <= 20, "Example uses up to 20 classes (A..T). Increase mapping if needed."

# Lookup tables derived from the class order:
#   id2char    : class index -> letter
#   char2id    : letter      -> class index
#   label2char : class label -> letter
id2char = [chr(code) for code in range(ord('A'), ord('A') + len(class_names))]
char2id = dict(zip(id2char, range(len(id2char))))
label2char = dict(zip(class_names, id2char))

# 3) Build a prompt template that forces a one-letter answer
def make_prompt(text, classes=None, names=None):
    """Build the classification prompt that asks for a one-letter answer.

    Args:
        text: Patient context, inserted into the prompt as-is.
        classes: One-character class IDs offered to the model. When omitted,
            the module-level ``id2char`` is used, looked up at call time.
        names: Labels paired with ``classes`` by position to form the legend.
            When omitted, the module-level ``class_names`` is used, looked up
            at call time.

    Returns:
        The full prompt string, ending with ``"Answer:"`` so the next
        generated token is expected to be the chosen letter.
    """
    # Resolve the defaults on every call rather than once at definition time:
    # if the mapping is rebuilt later (e.g. a notebook cell is re-run), the
    # prompt keeps matching the current module-level mapping.
    if classes is None:
        classes = id2char
    if names is None:
        names = class_names

    legend = "\n".join(f"{c} = {name}" for c, name in zip(classes, names))
    return (
        "You are a medical text classifier.\n"
        "Given the patient context below, choose the best diagnosis label.\n"
        f"Answer ONLY with one letter from this list and nothing else: {', '.join(classes)}\n\n"
        "Legend:\n"
        f"{legend}\n\n"
        "Patient context:\n"
        f"{text}\n\n"
        "Answer:"
    )

# 4) The probability function SHAP will call
def _normalize_label_token(token):
    """Reduce a raw model token to the bare, upper-case form used for class letters."""
    # Tokens can carry surrounding whitespace (" A", "A\n") or differ in case;
    # some tokenizers also render a leading space as U+2581.
    return str(token or "").replace("\u2581", " ").strip().upper()


def _class_probs_from_first_token(first, letters, floor=-50.0):
    """Convert the logprob entry of the first generated token into class probabilities.

    Args:
        first: Entry exposing ``.get`` for "token", "logprob" and "top_logprobs".
        letters: Class letters, in output column order.
        floor: Log-probability given to letters that do not appear among the
            returned candidates, so they stay tiny but non-zero.

    Returns:
        1-D float64 array of len(letters) that sums to 1.
    """
    # Collect candidates keyed by the RAW token so the generated token, which
    # may also be listed among the alternatives, is counted only once.
    raw_logprobs = {}
    token = first.get("token", "")
    logprob = first.get("logprob", None)
    if token and logprob is not None:
        raw_logprobs[token] = logprob
    for alt in first.get("top_logprobs", []) or []:
        alt_token = alt.get("token", "")
        alt_logprob = alt.get("logprob", None)
        if alt_token and alt_logprob is not None:
            raw_logprobs[alt_token] = alt_logprob

    # Merge every variant of the same letter ("A", " A", "a", ...) by adding
    # their probabilities, i.e. log-sum-exp in log space.
    wanted = set(letters)
    merged = {}
    for raw, lp in raw_logprobs.items():
        letter = _normalize_label_token(raw)
        if letter not in wanted:
            continue
        merged[letter] = np.logaddexp(merged[letter], lp) if letter in merged else lp

    logits = np.array([merged.get(c, floor) for c in letters], dtype=np.float64)

    # Softmax restricted to the class letters (shifted by the max for stability).
    # If no letter was found at all, every logit equals `floor` and the result
    # is uniform.
    logits -= logits.max()
    probs = np.exp(logits)
    return probs / probs.sum()


def ollama_predict_proba(texts, model_name="gemma3", top_logprobs=None):
    """
    Return class probabilities for each text, based on the model's first generated token.

    For every text a prompt is built with make_prompt and sent through
    client.generate with num_predict=1, requesting logprobs and the top
    alternatives. Candidate tokens are normalised (whitespace stripped,
    upper-cased) before being matched against the class letters, mirroring the
    normalisation of the fallback path; the matched log-probabilities are then
    renormalised over the class letters only.

    Args:
        texts: Sequence of input texts (must support len()).
        model_name: Model name passed to client.generate.
        top_logprobs: Number of alternatives to request for the token. Defaults
            to max(n_classes, 5) so every class letter has a chance to appear.

    Returns:
        np.ndarray of shape (n_samples, n_classes), float64; each row sums to 1.
        Columns follow the order of id2char.
    """
    n_classes = len(id2char)
    if top_logprobs is None:
        top_logprobs = max(n_classes, 5)  # ensure we capture all candidate letters

    out = np.zeros((len(texts), n_classes), dtype=np.float64)

    for i, t in enumerate(texts):
        # Single-token completion; deterministic; request logprobs & the top alternatives
        resp = client.generate(
            model=model_name,
            prompt=make_prompt(t),
            stream=False,
            options={
                "temperature": 0.0,
                "num_predict": 1,
                "top_k": 100,     # broad shortlist
                "top_p": 1.0,
            },
            logprobs=True,
            top_logprobs=top_logprobs,
        )

        logprobs_list = resp.get("logprobs", [])
        if not logprobs_list:
            # No logprobs returned: put all mass on the generated letter, or
            # spread uniformly if the reply is not one of the class letters.
            chosen = (resp.get("response", "") or "").strip()[:1].upper()
            if chosen in char2id:
                out[i, char2id[chosen]] = 1.0
            else:
                out[i, :] = 1.0 / n_classes
            continue

        # Only the first generated token matters (num_predict=1).
        out[i, :] = _class_probs_from_first_token(logprobs_list[0], id2char)

    return out

# 5) Wrap the probability function in a SHAP explainer for text
def _gemma3_proba(texts):
    """Adapter handed to SHAP: class probabilities from the "gemma3" model."""
    return ollama_predict_proba(texts, model_name="gemma3")


masker = shap.maskers.Text()  # text masker; no tokenizer is passed here
explainer = shap.Explainer(_gemma3_proba, masker=masker, output_names=class_names)

# 6) Keep at most the first 50 training texts as a small subset.
#    NOTE: this subset is not passed to the explainer in this cell.
if len(X_text_train) > 50:
    X_text_train_small = X_text_train[:50]
else:
    X_text_train_small = X_text_train

# Explain only the first 50 test texts to keep the run short.
shap_values_text = explainer(X_text_test[:50])

# 7) Show the token highlights for one class (index 0 here)
cls_idx = 0
shap.plots.text(shap_values_text[:, :, cls_idx])

