In [1]:
#%pip -q install -U transformers torch accelerate safetensors ipywidgets openai-harmony kernels

# Jupyter widgets are usually auto-enabled nowadays. If you don't see them,
# save & restart the kernel after this cell.
import ipywidgets as widgets  # sanity import


In [2]:
import os
from datetime import datetime
from typing import List, Dict, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Harmony helpers for parsing channels/roles from generated token IDs
from openai_harmony import (
    load_harmony_encoding,
    HarmonyEncodingName,
    Role,
)

# UI
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# --- Config (override via environment vars if you want) ---
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
REASONING_EFFORT = os.getenv("REASONING_EFFORT", "high")   # 'low' | 'medium' | 'high'
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.95"))

print("Using model:", MODEL_ID)

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

device = next(model.parameters()).device
print("Loaded on device:", device)

# Load Harmony encoding/parser
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)


Using model: openai/gpt-oss-20b


Fetching 40 files:   0%|          | 0/40 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded on device: cuda:0


In [3]:
def build_system_message() -> Dict[str, str]:
    """Harmony-style system message (adds date + desired reasoning effort)."""
    today = datetime.today().strftime("%Y-%m-%d")
    sys_text = f"""You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: {today}

Reasoning: {REASONING_EFFORT}

# Valid channels: analysis, commentary, final. Channel must be included for every message."""
    return {"role": "system", "content": sys_text}

def apply_chat_template(messages: List[Dict[str, str]]):
    """Use Transformers' chat template (applies Harmony formatting automatically)."""
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Handle case where apply_chat_template returns a tensor directly
    if isinstance(inputs, torch.Tensor):
        return {"input_ids": inputs.to(device)}
    else:
        return {k: v.to(device) for k, v in inputs.items()}

def parse_harmony_from_new_tokens(new_token_ids: List[int]) -> Tuple[str, str]:
    """
    Parse analysis (thinking) and final strings from generated tokens using Harmony.
    """
    msgs = encoding.parse_messages_from_completion_tokens(new_token_ids, Role.ASSISTANT)
    analysis_chunks = []
    final_text = ""
    for m in msgs:
        ch = getattr(m, "channel", "") or ""
        text = getattr(m, "content", "") or ""
        
        # Ensure text is always a string
        if isinstance(text, list):
            text = "".join(str(t) for t in text)
        elif not isinstance(text, str):
            text = str(text)
            
        if ch == "analysis":
            analysis_chunks.append(text)
        elif ch == "final":
            final_text += text
    return ("\n".join(analysis_chunks).strip(), final_text.strip())

def generate_one_turn(messages: List[Dict[str, str]]) -> Dict[str, str]:
    """
    Generate one assistant turn; returns dict with 'analysis' and 'final'.
    Only append the 'final' back to history to maintain clean user-visible chat.
    """
    inputs = apply_chat_template(messages)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens
    new_ids = out[0].tolist()[inputs["input_ids"].shape[1]:]
    analysis, final = parse_harmony_from_new_tokens(new_ids)
    return {"analysis": analysis, "final": final, "new_token_ids": new_ids}

In [4]:
# Initialize chat history
chat: list[dict] = [build_system_message()]

# Optional "developer" message per Harmony guidance
chat.append({
    "role": "developer",
    "content": (
        "You are a helpful assistant. Keep answers crisp unless asked otherwise. "
        "Use Markdown formatting when appropriate."
    ),
})

print("Chat primed with system + developer messages.")


Chat primed with system + developer messages.


In [5]:
def chat_cli():
    print("Type 'exit' to quit.\n")
    local_chat = chat.copy()
    while True:
        user_text = input("You: ").strip()
        if not user_text or user_text.lower() in {"exit", "quit"}:
            break
        local_chat.append({"role": "user", "content": user_text})
        turn = generate_one_turn(local_chat)

        print("\n—— Thinking ——————————————")
        print(turn["analysis"] if turn["analysis"] else "(no analysis emitted)")
        print("—— Final ————————————————")
        print(turn["final"])
        print("—————————————————————————\n")

        # Append only the final answer back to history
        local_chat.append({"role": "assistant", "content": turn["final"]})

# Uncomment to use the CLI:
# chat_cli()



In [None]:
show_thinking = widgets.Checkbox(value=True, description="Show thinking (analysis)")
inp = widgets.Text(placeholder="Ask me anything…", description="You")
send_btn = widgets.Button(description="Send", button_style="primary")
out = widgets.Output()

header = widgets.HBox([show_thinking])
composer = widgets.HBox([inp, send_btn])
app = widgets.VBox([header, out, composer])

def render_history(local_chat: list[dict], last_turn: dict | None):
    with out:
        clear_output()
        display(HTML("<h3>Conversation</h3>"))
        for msg in local_chat:
            role = msg.get("role")
            content = msg.get("content", "")
            if role == "user":
                display(HTML(f"<div><b>🧑 You:</b> {content}</div><hr/>"))
            elif role == "assistant":
                display(HTML(f"<div><b>🤖 Assistant:</b><br/>{content}</div><hr/>"))
            elif role in {"system", "developer"}:
                display(HTML(
                    f"""
<details>
  <summary><b>⚙️ {role.title()} message</b> (click to expand)</summary>
  <pre style='white-space:pre-wrap'>{content}</pre>
</details><hr/>
"""))
        if last_turn:
            if show_thinking.value:
                thinking = last_turn.get("analysis", "").strip()
                if thinking:
                    display(HTML(
                        f"""
<details open>
  <summary><b>🧠 Thinking (analysis)</b></summary>
  <pre style='white-space:pre-wrap'>{thinking}</pre>
</details>
"""))
            final = last_turn.get("final", "").strip()
            if final:
                display(HTML(f"<div><b>💬 Final:</b><br/>{final}</div>"))
        display(HTML("<hr/>"))

local_chat = chat.copy()
last_turn = None
render_history(local_chat, last_turn)

def on_send(_):
    global last_turn
    text = inp.value.strip()
    if not text:
        return
    inp.value = ""
    local_chat.append({"role": "user", "content": text})
    render_history(local_chat, None)

    # Generate assistant turn
    turn = generate_one_turn(local_chat)
    last_turn = turn

    # Append only final text back to history to preserve multi-turn continuity
    local_chat.append({"role": "assistant", "content": turn["final"]})
    render_history(local_chat, turn)

send_btn.on_click(on_send)
inp.on_submit(on_send)

display(app)


  inp.on_submit(on_send)


VBox(children=(HBox(children=(Checkbox(value=True, description='Show thinking (analysis)'),)), Output(), HBox(…

In [9]:
local_chat

[{'role': 'system',
  'content': 'You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-08-24\n\nReasoning: high\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.'},
 {'role': 'developer',
  'content': 'You are a helpful assistant. Keep answers crisp unless asked otherwise. Use Markdown formatting when appropriate.'},
 {'role': 'user', 'content': 'tell a joke'},
 {'role': 'assistant',
  'content': "text='Why don’t scientists trust atoms?\\n\\nBecause they make up everything!'"}]

In [8]:
# Test block to verify the parse_harmony_from_new_tokens fix
def test_parse_harmony_fix():
    """Test the parse_harmony_from_new_tokens function with various data types."""
    
    # Mock message objects to simulate different content types
    class MockMessage:
        def __init__(self, channel, content):
            self.channel = channel
            self.content = content
    
    # Test case 1: String content (normal case)
    print("Test 1: String content")
    test_msgs_str = [
        MockMessage("analysis", "This is analysis text"),
        MockMessage("final", "This is final text")
    ]
    
    # Manually test the logic
    analysis_chunks = []
    final_text = ""
    for m in test_msgs_str:
        ch = getattr(m, "channel", "") or ""
        text = getattr(m, "content", "") or ""
        
        # Apply the fix
        if isinstance(text, list):
            text = "".join(str(t) for t in text)
        elif not isinstance(text, str):
            text = str(text)
            
        if ch == "analysis":
            analysis_chunks.append(text)
        elif ch == "final":
            final_text += text
    
    print(f"Analysis: {analysis_chunks}")
    print(f"Final: {final_text}")
    print("✓ Test 1 passed\n")
    
    # Test case 2: List content (problematic case)
    print("Test 2: List content")
    test_msgs_list = [
        MockMessage("analysis", ["This", " is", " analysis"]),
        MockMessage("final", ["This", " is", " final"])
    ]
    
    analysis_chunks = []
    final_text = ""
    for m in test_msgs_list:
        ch = getattr(m, "channel", "") or ""
        text = getattr(m, "content", "") or ""
        
        # Apply the fix
        if isinstance(text, list):
            text = "".join(str(t) for t in text)
        elif not isinstance(text, str):
            text = str(text)
            
        if ch == "analysis":
            analysis_chunks.append(text)
        elif ch == "final":
            final_text += text
    
    print(f"Analysis: {analysis_chunks}")
    print(f"Final: {final_text}")
    print("✓ Test 2 passed\n")
    
    # Test case 3: Mixed content types
    print("Test 3: Mixed content types")
    test_msgs_mixed = [
        MockMessage("analysis", 12345),  # Integer
        MockMessage("final", ["Mixed", " content"]),  # List
        MockMessage("analysis", "Normal string"),  # String
        MockMessage("final", " and more")  # String
    ]
    
    analysis_chunks = []
    final_text = ""
    for m in test_msgs_mixed:
        ch = getattr(m, "channel", "") or ""
        text = getattr(m, "content", "") or ""
        
        # Apply the fix
        if isinstance(text, list):
            text = "".join(str(t) for t in text)
        elif not isinstance(text, str):
            text = str(text)
            
        if ch == "analysis":
            analysis_chunks.append(text)
        elif ch == "final":
            final_text += text
    
    print(f"Analysis: {analysis_chunks}")
    print(f"Final: {final_text}")
    print("✓ Test 3 passed\n")
    
    print("All tests passed! The fix should prevent TypeError.")

# Run the test
test_parse_harmony_fix()

Test 1: String content
Analysis: ['This is analysis text']
Final: This is final text
✓ Test 1 passed

Test 2: List content
Analysis: ['This is analysis']
Final: This is final
✓ Test 2 passed

Test 3: Mixed content types
Analysis: ['12345', 'Normal string']
Final: Mixed content and more
✓ Test 3 passed

All tests passed! The fix should prevent TypeError.
