## Core

In [None]:
#| default_exp core

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from dotenv import load_dotenv

In [None]:
#| hide
# Load variables from a .env file into the process environment (presumably the API key that
# claudette's Client needs — confirm). The trailing ';' keeps Jupyter from displaying the return value.
load_dotenv();

In [None]:
#| exporti
from typing import List, Dict, Tuple
import textwrap
from claudette import Client
from IPython import get_ipython
from IPython.display import display, clear_output, Markdown, Javascript

In [None]:
#| export

def parse_cell(
    cell: str # The raw body of the cell
) -> Tuple[List[Dict[str, str]], int]:
    """
    Split the raw text of a cell into chat messages.

    A single cell can contain multiple messages:
    - a line starting with `%fr` is (part of) a user message; the text after `%fr` is its content;
    - a line whose first non-blank character is `#` is (part of) an assistant message; the text
      after the `#` is its content, even when the comment is indented.
    Consecutive lines with the same role are merged into one message, joined by newlines, so
    both kinds of message can be multiline. All other lines are ignored.

    Returns: a list of messages (dicts with 'role' and 'content') and the number of %fr magics
    in the cell.
    """
    messages = []
    num_magic = 0
    for line in cell.split('\n'):
        stripped = line.strip()
        if line.startswith('%fr'):
            role, content = 'user', line[3:].strip()
            num_magic += 1
        elif stripped.startswith('#'):
            # Slice the stripped line so the '#' itself is removed even for indented comments.
            role, content = 'assistant', stripped[1:].strip()
        else:
            continue

        if messages and messages[-1]['role'] == role:
            messages[-1]['content'] += "\n" + content
        else:
            messages.append({'role': role, 'content': content})

    return messages, num_magic

In [None]:
#| exporti
#| hide
# Claude model identifiers; the shared client below uses the second entry (Sonnet).
models = [
    'claude-3-opus-20240229',
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]
chat_client = Client(model=models[1])

# State shared between the %fr invocations of one cell (read and written via `global` in
# fr_line): how many %fr lines are still to run, and the parsed conversation.
magic_count, messages = 0, []

In [None]:
#| export
def fr_line(line: str):
    """The magic function for the %fr magic command.

    A cell may contain several `%fr` lines (user turns) interleaved with `#` comment lines
    (assistant turns); IPython runs this function once per `%fr` line. The first invocation in a
    cell parses the whole cell with `parse_cell`; later invocations only count down. The last
    invocation sends the conversation to `chat_client`, provided the conversation ends with a
    non-empty user message. It streams the reply as Markdown, then rewrites the cell with the
    reply appended as `# ` comment lines, followed by a fresh `%fr` prompt.
    """
    global magic_count, messages
    ip = get_ipython()
    # Source of the whole cell being executed; the magic itself only receives its own line.
    raw_cell = ip.parent.get_parent()["content"]["code"]

    # Parse the cell on its first %fr only; later %fr lines of the same cell reuse `messages`.
    # NOTE(review): if a cell stops before its last %fr line runs, magic_count stays positive and
    # the next cell's %fr lines are counted against the stale state — confirm this is acceptable.
    if magic_count <= 0:
        messages, magic_count = parse_cell(raw_cell)

    # Only the message contents are sent, and claudette assigns roles to plain strings by
    # position (user first). Drop leading assistant messages (e.g. ordinary `#` comments above
    # the first %fr) so the positional roles line up with the parsed ones.
    first_user = next((i for i, m in enumerate(messages) if m['role'] == 'user'), len(messages))
    conversation = messages[first_user:]

    # Only the last %fr of the cell talks to the model, and only when the conversation ends
    # with a non-empty user message.
    last = conversation[-1] if conversation else None
    if magic_count == 1 and last is not None and last['role'] == 'user' and last['content'].strip():
        reply = ""
        display_id = display(Markdown("🚀..."), display_id=True)
        try:
            for token in chat_client([m['content'] for m in conversation], stream=True):
                reply += token
                display_id.update(Markdown(reply))

            if reply:
                # Wrap each line of the reply on its own so paragraph breaks, lists and code
                # lines survive; blank lines become a bare "#" so parse_cell keeps them.
                commented = "\n".join(
                    textwrap.fill(text=ln, width=100, initial_indent="# ", subsequent_indent="# ")
                    if ln.strip() else "#"
                    for ln in reply.splitlines()
                )
                raw_cell += f"\n{commented}\n\n%fr"
                ip.set_next_input(raw_cell, replace=True)

            clear_output()

        except BaseException as e:
            display_id.update(Markdown(f"🚫 {repr(e)}"))

    magic_count -= 1

## Friend**LL**y

In [None]:
#| eval: False
# Expose fr_line as the `%fr` line magic in the running IPython session.
ip = get_ipython()
ip.register_magic_function(fr_line, magic_kind='line', magic_name='fr')

In [None]:
# %fr Hello there! My name is Alex.
# # Hello Alex! It's nice to meet you. I'm an AI assistant created by Anthropic to be helpful,
# # harmless, and honest. How can I help you today?

# %fr

In [None]:
#| export

def load_ipython_extension(ipython):
    """Entry point for `%load_ext`: register `fr_line` as the `%fr` line magic."""
    ipython.register_magic_function(fr_line, 'line', magic_name='fr')

def unload_ipython_extension(ipython):
    """Entry point for `%unload_ext`: remove the `%fr` line magic added by `load_ipython_extension`.

    Does nothing if no `%fr` line magic is registered.
    """
    ipython.magics_manager.magics['line'].pop('fr', None)

In [None]:
#| export

from time import sleep

def inject_js(js:str):
    """Run `js` in the notebook frontend, then clear this output so the script is not kept in
    the saved notebook and re-run when it is reloaded."""
    script = Javascript(js)
    display(script)
    # NOTE(review): an earlier version used clear_output(wait=True); the author noted that wait=True
    # seemed necessary for "run all" to pick up the patched CodeCell.execute(). The immediate
    # clear below is what is currently in effect.
    clear_output()

In [None]:
    
def patch_kernel():
    """
    Monkey-patch `Jupyter.CodeCell.prototype.execute` in the classic Notebook frontend.

    The replacement appears to follow the stock implementation, but adds extra options to the
    execute request: `cell_index` and `current_id` for every cell, plus `all_cells` when the cell
    text starts with "##fr". Kernel-side code (e.g. the %%fr magic) reads `cell_index` from the
    request's `content`. The script also marks the notebook dirty.
    """
    # JavaScript payload; the non-stock part is the `extras` object merged into the execute options.
    payload = """
    Jupyter.CodeCell.prototype.execute = function (stop_on_error) {
        if (!this.kernel) {
            console.log(i18n.msg._("Can't execute cell since kernel is not set."));
            return;
        }

        if (stop_on_error === undefined) {
            if (this.metadata !== undefined &&
                    this.metadata.tags !== undefined) {
                if (this.metadata.tags.indexOf('raises-exception') !== -1) {
                    stop_on_error = false;
                } else {
                    stop_on_error = true;
                }
            } else {
               stop_on_error = true;
            }
        }

        this.clear_output(false, true);
        var old_msg_id = this.last_msg_id;
        if (old_msg_id) {
            this.kernel.clear_callbacks_for_msg(old_msg_id);
            delete Jupyter.CodeCell.msg_cells[old_msg_id];
            this.last_msg_id = null;
        }
        if (this.get_text().trim().length === 0) {
            // nothing to do
            this.set_input_prompt(null);
            return;
        }
        this.set_input_prompt('*');
        this.element.addClass("running");
        var callbacks = this.get_callbacks();


        const cell_index = Jupyter.notebook.find_cell_index(this)
        const cell_id = this.id
        let extras = {
            cell_index :cell_index,
            current_id: cell_id
        }
        if (this.get_text().trim().startsWith("##fr")) {
            extras = {
                all_cells: Jupyter.notebook.get_cells(),
                ...extras
            }
        }

        this.last_msg_id = this.kernel.execute(
            this.get_text(),
            callbacks,
            {silent: false, store_history: true, stop_on_error : stop_on_error, ...extras });
        Jupyter.CodeCell.msg_cells[this.last_msg_id] = this;
        this.render();
        this.events.trigger('execute.CodeCell', {cell: this});
        var that = this;
        function handleFinished(evt, data) {
            if (that.kernel.id === data.kernel.id && that.last_msg_id === data.msg_id) {
                    that.events.trigger('finished_execute.CodeCell', {cell: that});
                that.events.off('finished_iopub.Kernel', handleFinished);
              }
        }
        this.events.on('finished_iopub.Kernel', handleFinished);
    };
    Jupyter.notebook.events.trigger('set_dirty.Notebook', {value: true});

    """
    inject_js(payload)
    
# Install the patch when this cell runs.
patch_kernel()

In [None]:
#| export
def add_cell(
        idx:int = None, # Index of the cell to add. If None, add the cell under the selected one.
        cell_type:str = "code" # Type of cell to add. Can be "code", "markdown", "raw"
    ):
    """
    Add a new, empty notebook cell of type `cell_type` at 0-based position `idx`.

    Works by injecting JavaScript for the classic Jupyter Notebook frontend (`Jupyter.notebook`).
    """
    # Compare against None explicitly: idx=0 (insert at the very top) is falsy but valid.
    if idx is None:
        index_payload = "let index = Jupyter.notebook.get_selected_index()+1;"
    else:
        index_payload = f"let index = {idx}"

    payload = f"""
    {index_payload}

    //console.log("add cell start, ncell=", Jupyter.notebook.ncells())
    Jupyter.notebook.insert_cell_at_index("{cell_type}", index)
    //Jupyter.notebook.insert_cell_below();
    //Jupyter.notebook.events.trigger('set_dirty.Notebook', {{value: true}});
    let cell = Jupyter.notebook.get_cell(index);
    //cell.set_text(`add_cell(${{index + 1}})\\nexecute_cell(${{index+1}})`)
    cell.events.trigger('set_dirty.Notebook', {{value: true}});
    //console.log("add cell end ncell=", Jupyter.notebook.ncells())
    """

    inject_js(payload)

In [None]:
#| export
def update_cell(
    idx:int, # 0-based index of the cell to update
    text:str, # Text to set in the cell
    flush:bool = True # Notify Jupyter that the cell has been updated.
    ):
    """
    Replace the source of the notebook cell at `idx` with `text`.

    Works by injecting JavaScript for the classic Jupyter Notebook frontend. `idx` must be an
    explicit integer index. With `flush`, a `set_dirty.Notebook` event is also triggered.
    """
    import json  # local import so this function does not rely on a notebook-level `import json`

    def escape_for_js(text):
        # json.dumps escapes backslashes, quotes, control and non-ASCII characters; without its
        # surrounding quotes the result is safe inside a JS template literal once backticks and
        # "${" (template interpolation) are escaped as well.
        escaped = json.dumps(text)[1:-1]
        return escaped.replace('`', '\\`').replace('${', '\\${')

    payload = f"""
    let cell = Jupyter.notebook.get_cell({idx})
    cell.set_text(`{escape_for_js(text)}`)
    //cell.events.trigger('set_dirty.Notebook', {{value: true}});
    """
    if flush:
        # Plain (non-f) string: the JS object literal needs single braces here.
        payload += "\nJupyter.notebook.events.trigger('set_dirty.Notebook', {value: true});"
    inject_js(payload)

In [None]:
#| export

def execute_cell(
        idx:int # Index of the cell to execute. They start at 0
    ):
    """
    Run the notebook cell at 0-based position `idx` in the classic Notebook frontend.

    The injected script logs the call to the browser console and marks the notebook dirty
    before calling the cell's `execute()`.
    """
    inject_js(f"""
    console.log("execute_cell", {idx});
    Jupyter.notebook.events.trigger('set_dirty.Notebook', {{value: true}});
    let cell = Jupyter.notebook.get_cell({idx})
    cell.execute()
    //Jupyter.notebook.execute_cell_range({idx, idx+1})
    //Jupyter.notebook.events.trigger('set_dirty.Notebook', {{value: true}});
    """)

In [None]:
def render_cell(
    idx:int # Cell to render.
    ):
    """
    Re-render the markdown cell at 0-based position `idx` so its current source is displayed.

    The injected script unrenders the cell, marks the notebook dirty, then renders it again.
    """
    inject_js(f"""
    let cell = Jupyter.notebook.get_cell({idx})
    cell.unrender()
    Jupyter.notebook.events.trigger('set_dirty.Notebook', {{value: true}});

    cell.render()

    """)


In [None]:
# System prompt for the %%fr cell magic. fr_cell switches to a new code cell at the first
# "<code>" in the reply and stops generation at the "</code>" stop sequence, so the model must
# close the tag, and only one code block (at the end of the reply) can be used.
system = """
You are an AI assistant that can work with jupyter notebooks. Your reply will be rendered as markdown.
To add a new code cell that the user can execute, wrap its code in <code></code> tags.
Add at most one such code cell per reply, and put it at the end of your reply.
"""

# Module-level token buffer; presumably a debugging aid for inspecting streamed tokens.
tokens = []

import json
def fr_cell(line=None, cell=None):
    """Cell magic `%%fr`: send the cell body to the model and stream the answer into new cells.

    The prose part of the reply is streamed into a new markdown cell inserted right below this
    one. If the reply contains `<code>`, everything after it (up to the `</code>` stop sequence)
    is streamed into a second, code cell below the markdown cell. That cell is executed only if
    the model actually closed the tag.

    `line`/`cell` are the arguments IPython passes to cell magics. The prompt is instead read
    from the execute request, which also carries the `cell_index` added by `patch_kernel`.
    """
    global tokens
    tokens = []  # raw streamed tokens of this call, kept for debugging

    ip = get_ipython()
    request = ip.parent_header["content"]
    md_idx = request["cell_index"] + 1  # markdown cell for the prose part of the reply
    code_idx = md_idx + 1               # code cell for the <code> part, if any

    source = request["code"].strip()
    if not source.startswith("%%fr"):
        raise ValueError("fr_cell must be invoked as the %%fr cell magic")
    # The prompt goes to the model as-is: no JSON/JS escaping, it never enters JavaScript.
    prompt = source[len("%%fr"):].strip()

    add_cell(md_idx, "markdown")
    try:
        stream = chat_client([prompt], sp=system, stream=True, stop="</code>")
        reply, code, has_code = "", "", False
        for token in stream:
            tokens.append(token)
            reply += token
            if "<code>" in reply:
                # Text before the first <code> is prose; the rest starts the code cell.
                reply, _, code = reply.partition("<code>")
                has_code = True
            update_cell(md_idx, reply)
            render_cell(md_idx)
            if has_code:
                break

        if has_code:
            add_cell(code_idx, "code")
            # Keep consuming the same stream: it continues right after the <code> tag.
            for token in stream:
                tokens.append(token)
                code += token
                update_cell(code_idx, code.strip())
            update_cell(code_idx, code.strip())
            # "</code>" is the stop sequence that ends a complete code block. Any other stop
            # reason (e.g. running out of tokens) may have left truncated code, so don't run it.
            if chat_client.stop_reason == "stop_sequence":
                execute_cell(code_idx)
            else:
                display(Markdown(
                    f"🚫 code block was not closed (stop reason: {chat_client.stop_reason!r}); not executing it"))

    except BaseException as e:
        display(Markdown(f"🚫 {repr(e)}"))


# Expose fr_cell as the `%%fr` cell magic in the running IPython session.
ip = get_ipython()
ip.register_magic_function(fr_cell, magic_kind='cell', magic_name='fr')

# fr_cell()

# FriendLLy AI overlords

In [None]:
%%fr

Find the square root of pi