# DDXPlus Multistep QA — LLM Rewrite/Test (Ollama → OpenAI)

This notebook rewrites `question_stem` into a consistent **"Suppose X happens, how will this affect Y?"** style.

- First test uses **Ollama** via its OpenAI-compatible endpoint: `http://localhost:11434/v1`.
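- Later you can switch to **OpenAI** by setting the environment variables `OPENAI_BASE_URL` to `https://api.openai.com/v1`, `OPENAI_API_KEY` to your key, and `OPENAI_MODEL` to an OpenAI model name (or by editing `BASE_URL`, `API_KEY`, and `MODEL` in the configuration cell).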


## 1. Imports

In [None]:
import os
import json
import time
from typing import Any, Dict, List, Optional

import requests

## 2. Configuration

In [None]:
# --- Files ---
# Source dataset (one JSON object per line) and the path for the rewritten copy.
INPUT_JSONL = 'DDXPlus_CausalQA_multistep.jsonl'
OUTPUT_JSONL = 'DDXPlus_CausalQA_multistep_rewritten.jsonl'

# --- Endpoint (OpenAI-compatible) ---
# Each value is read from an environment variable when it is set; otherwise
# the local Ollama default on the right is used.
#   Ollama: http://localhost:11434/v1
#   OpenAI: https://api.openai.com/v1
BASE_URL = os.getenv('OPENAI_BASE_URL', 'http://localhost:11434/v1')
API_KEY = os.getenv('OPENAI_API_KEY', 'ollama')

# --- Model ---
#   Ollama example: llama3.1:8b
#   OpenAI example: gpt-4o-mini
MODEL = os.getenv('OPENAI_MODEL', 'llama3.1:8b')

# --- Generation / request settings used by the rewrite calls ---
TEMPERATURE = 0.7
MAX_TOKENS = 120
TIMEOUT_S = 60  # passed to the HTTP request as its timeout

# --- Run size ---
N_TEST = 5       # number of rows shown by the quick test
MAX_REWRITE = 0  # 0 = all rows

print(f'BASE_URL: {BASE_URL}')
print(f'MODEL: {MODEL}')

## 3. Load JSONL

In [None]:
def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file into a list of parsed objects.

    The file is opened as UTF-8 text. Every line is stripped of surrounding
    whitespace; lines that are empty after stripping are ignored and each
    remaining line is parsed with ``json.loads``.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The parsed objects, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


# Load the dataset once and print a short summary: the row count, the sorted
# key names of the first row, and that row's current question stem.
rows = load_jsonl(INPUT_JSONL)
print(f'rows: {len(rows)}')
if rows:
    print('keys:', sorted(rows[0].keys()))
    print('sample question_stem:', rows[0].get('question_stem', ''))
else:
    # Nothing loaded: print the same labels with empty placeholders.
    print('keys:', [])
    print('sample question_stem:', '')

## 4. OpenAI-compatible `chat.completions` call

This works for both Ollama and OpenAI; switching between them only requires changing `BASE_URL`, `API_KEY`, and `MODEL`.

In [None]:
def chat_completions(
    *,
    base_url: str,
    api_key: str,
    model: str,
    messages: List[Dict[str, str]],
    temperature: float = 0.7,
    max_tokens: int = 256,
    timeout_s: int = 60,
) -> str:
    """POST a chat request to an OpenAI-compatible endpoint and return the reply text.

    Args:
        base_url: API root; ``/chat/completions`` is appended after removing
            any trailing slashes.
        api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
        model: Model name placed in the request payload.
        messages: Chat messages, each a dict such as
            ``{'role': ..., 'content': ...}``.
        temperature: Sampling temperature (coerced to ``float``).
        max_tokens: Completion token limit (coerced to ``int``).
        timeout_s: Timeout handed to ``requests.post``.

    Returns:
        The ``content`` of the first choice's message. A null or absent
        ``content`` is returned as an empty string so the declared ``str``
        return type always holds.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (via ``raise_for_status``).
        KeyError, IndexError: If the response JSON has no
            ``choices[0].message``.
    """
    url = base_url.rstrip('/') + '/chat/completions'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
    }
    payload = {
        'model': model,
        'messages': messages,
        'temperature': float(temperature),
        'max_tokens': int(max_tokens),
    }
    r = requests.post(url, headers=headers, json=payload, timeout=timeout_s)
    r.raise_for_status()
    data = r.json()
    content = data['choices'][0]['message'].get('content')
    # The Chat Completions schema allows `content` to be null; callers
    # immediately call str methods on the result, so never hand back None.
    return '' if content is None else content


# Connectivity probe: send one tiny prompt with the configured endpoint,
# key and model, then print the stripped reply.
test_out = chat_completions(
    messages=[{'role': 'user', 'content': 'Reply with exactly: OK'}],
    model=MODEL,
    base_url=BASE_URL,
    api_key=API_KEY,
    timeout_s=TIMEOUT_S,
    max_tokens=8,
    temperature=0.0,
)
print(f'connectivity: {test_out.strip()}')

## 5. Rewrite `question_stem` (LLM)

We use `cause_event` + `outcome_base` (+ optional `outcome_polarity`) as the *meaning anchor*.

Important: the rewritten question must **NOT** state the answer direction.

In [None]:
def outcome_base_to_phrase(outcome_base: str, outcome_polarity: Optional[str] = None) -> str:
    """Turn an ``outcome_base`` label into a noun phrase for the question.

    Rules, checked in order (suffix matching is case-insensitive):

    * empty / missing label      -> ``'the outcome'``
    * ``'<name> probability'``   -> ``'the probability of <name>'``, or
      ``'the more|less probability of <name>'`` when *outcome_polarity* is
      ``'more'`` or ``'less'`` (case and surrounding whitespace ignored)
    * ``'<name> rate'``          -> ``'the rate of <name>'``
    * anything else              -> the stripped label unchanged

    Args:
        outcome_base: Raw outcome label; ``None`` is treated as empty.
        outcome_polarity: Optional polarity marker; only used for
            probability labels.

    Returns:
        The phrase to embed in the question.
    """
    target = (outcome_base or '').strip()
    if not target:
        return 'the outcome'
    folded = target.lower()

    prob_suffix = ' probability'
    if folded.endswith(prob_suffix):
        subject = target[: -len(prob_suffix)].strip()
        polarity = (outcome_polarity or '').strip().lower()
        # Only the two recognised polarity words are carried into the phrase.
        qualifier = f'{polarity} ' if polarity in ('more', 'less') else ''
        return f'the {qualifier}probability of {subject}'

    rate_suffix = ' rate'
    if folded.endswith(rate_suffix):
        subject = target[: -len(rate_suffix)].strip()
        return f'the rate of {subject}'

    return target


def build_base_question(cause_event: str, outcome_base: str, outcome_polarity: Optional[str] = None) -> str:
    """Compose the template question ``Suppose X happens, how will this affect Y?``.

    Args:
        cause_event: Text for X. It is stripped and trailing periods are
            removed; if nothing remains, a generic placeholder is used.
        outcome_base: Outcome label, converted to the Y phrase by
            ``outcome_base_to_phrase``.
        outcome_polarity: Optional polarity forwarded to
            ``outcome_base_to_phrase``.

    Returns:
        The assembled question string.
    """
    trimmed_cause = (cause_event or '').strip().rstrip('.')
    cause = trimmed_cause or 'the situation described above'
    target = outcome_base_to_phrase(outcome_base, outcome_polarity=outcome_polarity)
    return f'Suppose {cause} happens, how will this affect {target}?'


def _pick_question_line(raw: Optional[str], fallback: str) -> str:
    """Select the rewritten question from raw model output, or *fallback* if there is none."""
    candidates: List[str] = []
    for line in (raw or '').splitlines():
        text = line.strip()
        # Drop one pair of matching quotes/backticks wrapping the whole line.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in '"\'`':
            text = text[1:-1].strip()
        if text:
            candidates.append(text)
    if not candidates:
        # Empty or whitespace-only reply: keep the deterministic template.
        return fallback
    # The prompt requires the answer to start with "Suppose"; prefer that line
    # so a preamble such as "Here is the rewritten question:" is skipped.
    for text in candidates:
        if text.lower().startswith('suppose'):
            return text
    return candidates[0]


def rewrite_question_stem(item: Dict[str, Any]) -> str:
    """Ask the LLM to rephrase an item's question as one fluent "Suppose ..." question.

    A template question is built from the item's ``cause_event``,
    ``outcome_base`` and ``outcome_polarity`` and sent to the configured
    chat endpoint together with the rewriting rules.

    Args:
        item: Dataset row; missing fields are treated as empty.

    Returns:
        One question ending in ``?``. The first output line starting with
        "Suppose" is used (else the first non-empty line); if the model
        returns nothing usable, the template question is returned instead.

    Raises:
        Whatever ``chat_completions`` raises for HTTP or response errors.
    """
    base_q = build_base_question(
        item.get('cause_event', ''),
        item.get('outcome_base', ''),
        outcome_polarity=item.get('outcome_polarity'),
    )
    system = 'You rewrite causal questions into fluent English without changing meaning.'
    user = f"""Rewrite the following question into ONE grammatical English question.

Rules:
- Preserve meaning: same cause event X and same outcome target Y (including any MORE/LESS outcome meta).
- DO NOT state the direction of effect for Y (do not say increase/decrease, more/less likely, etc.).
- If Y contains the words 'more' or 'less', they must ONLY appear as part of the outcome target phrase.
- Output ONLY the rewritten question sentence.
- Must be ONE sentence and end with a '?'.
- Must start with the word "Suppose".

Question:
{base_q}
"""
    out = chat_completions(
        base_url=BASE_URL,
        api_key=API_KEY,
        model=MODEL,
        messages=[
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': user},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        timeout_s=TIMEOUT_S,
    )
    q = _pick_question_line(out, fallback=base_q)
    # Force question punctuation even if the model ended with a period.
    if not q.endswith('?'):
        q = q.rstrip('.') + '?'
    return q


def rewrite_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of *item* with its question stem rewritten.

    The previous ``question_stem`` (or ``''`` if absent) is kept under
    ``question_stem_original``; ``question_stem`` is replaced by the output
    of ``rewrite_question_stem``. The input dict is not modified.
    """
    return {
        **item,
        'question_stem_original': item.get('question_stem', ''),
        'question_stem': rewrite_question_stem(item),
    }

## 6. Quick test (Ollama)

In [None]:
# Rewrite the first N_TEST rows and show each original stem next to its
# rewrite, pausing briefly between requests.
for i, item in enumerate(rows[:N_TEST]):
    rewritten = rewrite_item(item)
    print(
        f'\n--- sample {i} ---',
        f"original: {rewritten['question_stem_original']}",
        f"rewritten: {rewritten['question_stem']}",
        sep='\n',
    )
    time.sleep(0.2)

## 7. Rewrite and save JSONL (optional)

Set `MAX_REWRITE=0` to rewrite all rows.

In [None]:
def write_jsonl(items: List[Dict[str, Any]], path: str) -> None:
    """Write *items* to *path* as JSON Lines (UTF-8), one object per line.

    The file is opened in write mode, so existing content is replaced.
    Non-ASCII characters are written as-is rather than escaped.

    Args:
        items: Objects to serialise, in output order.
        path: Destination file path.
    """
    with open(path, 'w', encoding='utf-8') as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in items
        )


# Rewrite the first `n` rows and save them as JSONL.
# MAX_REWRITE <= 0 means "all rows"; a negative value previously produced an
# empty output file.
n = len(rows) if MAX_REWRITE <= 0 else min(len(rows), MAX_REWRITE)
out_rows: List[Dict[str, Any]] = []
finished = False
try:
    for i, row in enumerate(rows[:n]):
        out_rows.append(rewrite_item(row))
        if (i + 1) % 10 == 0:
            print('rewritten', i + 1, '/', n)
        time.sleep(0.1)
    finished = True
finally:
    # Each row costs one HTTP call, so a long run can fail or be interrupted
    # part-way. Keep the rows already rewritten instead of discarding them.
    # They go to a side file so an existing complete OUTPUT_JSONL is never
    # replaced by a partial one.
    if not finished and out_rows:
        partial_path = OUTPUT_JSONL + '.partial'
        write_jsonl(out_rows, partial_path)
        print('stopped after', len(out_rows), '/', n, 'rows; partial output saved to:', partial_path)

write_jsonl(out_rows, OUTPUT_JSONL)
print('saved to:', OUTPUT_JSONL)