# Summary
- Ports `src/reasoning_shap.py` into an interactive workflow so you can inspect reasoning-step SHAP analysis directly in Jupyter while experimenting with prompts and models.

# Inputs
- Problem statement text for the reasoning task.
- Optional list of `ReasoningStep` objects or an Ollama endpoint for automatic step generation.
- Model/vectorizer selection (`OllamaModel`, embeddings, or the dummy demo model provided here).

# Outputs
- `pandas.DataFrame` of reasoning-step coalitions with similarity scores.
- Bar chart of reasoning-step importance values.
- Color-highlighted reasoning narrative that emphasizes high-importance sentences.


In [None]:
from pathlib import Path
import sys
import math
import hashlib
from typing import List, Optional
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
from IPython.display import display, HTML

# Make the project's `src/` directory importable. The path is resolved from the
# current working directory, so the notebook is expected to be started from the
# project root. The membership check keeps re-runs of this cell from appending
# duplicate entries to sys.path.
PROJECT_ROOT = Path.cwd()
SRC_PATH = PROJECT_ROOT / 'src'
if str(SRC_PATH) not in sys.path:
    sys.path.append(str(SRC_PATH))

# NOTE(review): the 'seaborn-v0_8' style name presumably needs matplotlib >= 3.6 — confirm
# against the pinned matplotlib version.
plt.style.use('seaborn-v0_8')


In [None]:
from reasoning_shap import ReasoningSHAP, ReasoningStep
from base import ModelBase, TfidfTextVectorizer, EmbeddingVectorizer, OllamaModel


class DummyReasoningModel(ModelBase):
    '''Deterministic stand-in for Ollama that scores prompts by coverage.

    The "confidence" reported for a prompt is the fraction of expected
    reasoning steps that appear in it (counted as occurrences of the literal
    text ``'Step '``), plus a small jitter derived from a SHA-1 hash of the
    prompt, so identical prompts always produce identical output.
    '''

    def __init__(self, expected_steps: Optional[int] = None, noise_scale: float = 0.05):
        '''Store the scoring parameters.

        Args:
            expected_steps: Total number of steps a complete prompt contains.
                When ``None``, the count found in each prompt is used, so
                coverage is 1.0 whenever at least one step is present.
            noise_scale: Multiplier applied to the hash-derived jitter, which
                lies in ``[-1.0, 1.0)`` before scaling.
        '''
        super().__init__(model_name='dummy-reasoning-model')
        self.expected_steps = expected_steps
        self.noise_scale = noise_scale

    def generate(self, prompt: str) -> str:
        '''Return a synthetic, reproducible response for ``prompt``.

        Args:
            prompt: Prompt text; each occurrence of ``'Step '`` counts as one
                included reasoning step.

        Returns:
            A multi-line string with a confidence score clipped to
            ``[0.0, 1.0]``, the included/total step counts, and the prompt
            echoed back as the reasoning summary.
        '''
        present_steps = prompt.count('Step ')
        total_steps = self.expected_steps if self.expected_steps is not None else max(present_steps, 1)
        # max(..., 1) guards against division by zero when expected_steps == 0.
        coverage = present_steps / max(total_steps, 1)
        # Hash-based jitter instead of random noise keeps the output reproducible.
        digest = hashlib.sha1(prompt.encode('utf-8')).hexdigest()
        jitter = ((int(digest[:6], 16) % 2000) / 1000 - 1.0) * self.noise_scale
        score = float(np.clip(coverage + jitter, 0.0, 1.0))
        # The line breaks must be '\n' escapes: a single-quoted (f-)string
        # literal cannot span physical lines.
        return (
            f"Confidence Score: {score:.3f}\n"
            f"Steps Included: {present_steps}/{total_steps}\n"
            'Reasoning Summary:\n'
            f"{prompt}"
        )


class NotebookReasoningSHAP(ReasoningSHAP):
    '''ReasoningSHAP variant with a lightweight fallback attribution routine.

    When the parent implementation cannot import one of its dependencies, the
    Shapley values are approximated from the coalition table alone.
    '''

    def _calculate_shapley_values(self, df: pd.DataFrame, content):
        '''Delegate to the parent implementation, falling back on import errors.

        Args:
            df: Coalition table passed through to the parent implementation
                (and to the fallback, which reads its ``'Indexes'`` and
                ``'Similarity'`` columns).
            content: Passed through unchanged.

        Returns:
            Whatever the parent returns, or the result of
            ``_fallback_shapley_values`` if the parent raised ``ImportError``.
        '''
        try:
            return super()._calculate_shapley_values(df, content)
        except ImportError as exc:
            # ModuleNotFoundError is a subclass of ImportError, so this single
            # clause covers both; the message still reports the concrete class.
            self._debug_print(f"Falling back to simplified attribution because {exc.__class__.__name__}: {exc}")
            return self._fallback_shapley_values(df, content)

    def _fallback_shapley_values(self, df: pd.DataFrame, content) -> dict:
        '''Approximate per-step importance without the optional dependencies.

        For each step, the contribution is the mean ``'Similarity'`` of the
        rows whose ``'Indexes'`` contain the step's 1-based position minus the
        mean of the rows that do not. Contributions are then shifted so the
        smallest is zero and normalised to sum to one.

        Args:
            df: Coalition table with an ``'Indexes'`` column (iterables of
                1-based step positions, or ``None``) and a ``'Similarity'``
                column.
            content: Forwarded to ``self._get_samples`` to obtain the steps.

        Returns:
            A dict keyed by ``f"{step}_{position}"``. Empty when there are no
            steps; uniform weights when all contributions are equal.
        '''
        steps = self._get_samples(content)
        if not steps:
            return {}

        def normalize_indexes(indexes):
            # Coerce the various container types a cell may hold into a tuple.
            if indexes is None:
                return tuple()
            if isinstance(indexes, tuple):
                return indexes
            if isinstance(indexes, list):
                return tuple(indexes)
            if hasattr(indexes, 'tolist'):
                return tuple(indexes.tolist())
            return tuple(indexes)

        # Normalise every row once up front instead of twice per step.
        coalitions = [normalize_indexes(ids) for ids in df['Indexes']]

        contributions = {}
        for target, step in enumerate(steps, start=1):
            # One positional boolean mask per step; its negation selects the
            # complementary rows.
            mask = np.fromiter(
                (target in members for members in coalitions),
                dtype=bool,
                count=len(coalitions),
            )
            contain = df[mask]
            without = df[~mask]
            contain_score = float(contain['Similarity'].mean()) if not contain.empty else 0.0
            without_score = float(without['Similarity'].mean()) if not without.empty else 0.0
            contributions[f"{step}_{target}"] = contain_score - without_score

        # Shift so the least helpful step gets weight zero, then normalise.
        min_value = min(contributions.values())
        shifted = {k: v - min_value for k, v in contributions.items()}
        total = sum(shifted.values())
        if total <= 0:
            # All steps contributed equally: spread the weight uniformly.
            uniform = 1.0 / max(len(shifted), 1)
            return {k: uniform for k in shifted}
        return {k: v / total for k, v in shifted.items()}


## Configure the analyzer
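Toggle `USE_OLLAMA` to switch between the dummy demo pipeline and the full Ollama-backed workflow.  \
Running the full pipeline requires `optuna`, `shap`, and `xgboost` plus an accessible Ollama endpoint. If importing those packages fails during the Shapley computation, `NotebookReasoningSHAP` falls back to a simplified attribution instead.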


In [None]:
import os

USE_OLLAMA = False  # set to True for the full Phi-4 reasoning pipeline

# Word problem handed to the analyzer. `.strip()` removes the leading and
# trailing newlines introduced by the triple-quoted layout.
problem = '''
A train leaves Station A at 9:00 AM traveling at 60 mph toward Station B.
Another train leaves Station B at 10:00 AM traveling at 80 mph toward Station A.
The distance between the stations is 280 miles.
At what time will the trains meet?
'''.strip()

# Hand-written reasoning chain used when USE_OLLAMA is False (the Ollama branch
# generates its own steps instead).
# NOTE(review): this chain ignores Train A's one-hour head start. By 10:00 AM
# Train A has already covered 60 miles, so the remaining gap is 220 miles and
# the trains meet 220 / 140 ≈ 1.57 h later (about 11:34 AM), not at 11:00 AM.
# Confirm whether the flawed reasoning is intentional demo data before relying
# on the rendered narrative.
manual_steps: List[ReasoningStep] = [
    ReasoningStep(
        step_number=1,
        description='Combine train speeds',
        content=(
            'When two trains head toward each other, their individual speeds add. '
            'Add 60 mph and 80 mph to obtain a combined closing speed of 140 mph.'
        ),
    ),
    ReasoningStep(
        step_number=2,
        description='Relate distance and time',
        content=(
            'The distance between the stations is 280 miles. '
            'Time to meet equals distance divided by closing speed.'
        ),
    ),
    ReasoningStep(
        step_number=3,
        description='Compute meeting interval',
        content=(
            'Divide 280 miles by 140 mph to get 2 hours. '
            'This interval is measured from the moment the first train departs Station A.'
        ),
    ),
    ReasoningStep(
        step_number=4,
        description='Account for staggered departures',
        content=(
            'Train B leaves one hour after Train A. '
            'Subtract the one-hour head start to see that Train B travels for only 1 hour before meeting.'
        ),
    ),
    ReasoningStep(
        step_number=5,
        description='State the meeting time',
        content=(
            'Train A has been on the track for 2 hours, so the trains meet at 11:00 AM. '
            'Train B has been moving for 1 hour and arrives at the same moment.'
        ),
    ),
]

# Build the model / vectorizer / analyzer triple for the selected mode.
if USE_OLLAMA:
    # Endpoint and model names can be overridden through the environment
    # variables API_URL, OLLAMA_MODEL and EMBEDDING_MODEL; the literals below
    # are the defaults used when a variable is unset.
    api_url = os.getenv('API_URL', 'http://localhost:11434')
    model_name = os.getenv('OLLAMA_MODEL', 'phi4-reasoning:latest')
    model = OllamaModel(model_name=model_name, api_url=api_url)
    vectorizer = EmbeddingVectorizer(model_name=os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2'))
    analyzer = NotebookReasoningSHAP(model=model, vectorizer=vectorizer, debug=False)
else:
    # Offline demo: the dummy model is told how many steps a complete prompt
    # contains so that its coverage score is relative to the full chain.
    model = DummyReasoningModel(expected_steps=len(manual_steps))
    vectorizer = TfidfTextVectorizer()
    # Debug output is enabled only in demo mode.
    analyzer = NotebookReasoningSHAP(model=model, vectorizer=vectorizer, debug=True)


In [None]:
# Run the analysis. The Ollama branch asks the analyzer to generate 5 steps
# itself; the demo branch supplies the hand-written `manual_steps`.
# NOTE(review): `sampling_ratio` and `max_combinations` presumably bound how
# many step coalitions are evaluated — confirm against `ReasoningSHAP.analyze_reasoning`.
if USE_OLLAMA:
    results_df = analyzer.analyze_reasoning(
        problem=problem,
        steps=None,
        auto_generate_steps=True,
        num_steps=5,
        sampling_ratio=0.3,
        max_combinations=80,
    )
else:
    results_df = analyzer.analyze_reasoning(
        problem=problem,
        steps=manual_steps,
        auto_generate_steps=False,
        sampling_ratio=0.4,
        max_combinations=32,
    )

# Last expression of the cell: show the highest-similarity rows.
results_df.sort_values('Similarity', ascending=False).head()


In [None]:
# Presumably draws the bar chart of reasoning-step importance values listed
# under "Outputs" at the top of the notebook; requires the analysis cell to
# have run first.
analyzer.plot_importance()


In [None]:
def render_colored_reasoning(analyzer: NotebookReasoningSHAP, top_n: Optional[int] = None) -> None:
    '''Display the reasoning steps as HTML, shaded by their importance.

    Steps are listed from most to least important. Every sentence of a step is
    highlighted with a colour from the ``YlOrRd`` colormap; higher importance
    maps to a darker shade.

    Args:
        analyzer: Analyzer whose ``shapley_values`` (mapping of key to
            importance) and ``reasoning_steps`` are read.
        top_n: If given, only the ``top_n`` most important entries are shown.

    Returns:
        None. The HTML is emitted through ``IPython.display.display``.
    '''
    import html  # local import so this cell does not depend on edits to the import cell

    shap_vals = analyzer.shapley_values or {}
    sorted_items = sorted(shap_vals.items(), key=lambda item: item[1], reverse=True)
    if top_n is not None:
        sorted_items = sorted_items[:top_n]
    # Also covers top_n == 0, which would otherwise make min()/max() below
    # fail on an empty list.
    if not sorted_items:
        display(HTML('<p>No Shapley values available to visualize.</p>'))
        return

    steps = analyzer.reasoning_steps
    cmap = plt.cm.YlOrRd

    values = [value for _, value in sorted_items]
    min_val, max_val = min(values), max(values)

    def scale(value: float) -> float:
        # Map importance onto [0, 1]; use a fixed mid-high shade when every
        # value is (nearly) identical to avoid dividing by zero.
        if math.isclose(max_val, min_val):
            return 0.7
        return (value - min_val) / (max_val - min_val)

    html_parts = []
    for key, value in sorted_items:
        # The fallback attribution builds keys as f"{step}_{position}", so the
        # trailing "_N" is the reliable step number; the step text itself may
        # contain unrelated digits (e.g. "60 mph"). Keys without that suffix
        # still use the first digit run, as before.
        match = re.search(r'_(\d+)$', key) or re.search(r'(\d+)', key)
        if not match:
            continue
        step_num = int(match.group(1))
        step = next((s for s in steps if s.step_number == step_num), None)
        if step is None:
            continue
        # Restrict to the 0.35-0.85 band of the colormap so text stays legible.
        color = mcolors.to_hex(cmap(0.35 + 0.5 * scale(value)))
        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', step.content) if s.strip()]
        # Step text may be model-generated; escape it so characters such as
        # '<' or '&' are shown literally instead of being parsed as markup.
        sentence_html = ' '.join(
            f"<span style='background-color:{color}; padding:2px 4px; border-radius:4px;'>{html.escape(sentence)}</span>"
            for sentence in sentences
        )
        description = html.escape(str(step.description))
        html_parts.append(
            f"<p><strong>Step {step.step_number}: {description}</strong> — importance {value:.2%}<br>{sentence_html}</p>"
        )

    display(HTML(''.join(html_parts)))


In [None]:
# Show every analysed step (no top_n limit), most important first.
render_colored_reasoning(analyzer)
