In [None]:
#| default_exp core

# dialoghelper

In [None]:
#| export
import inspect, json, importlib, linecache
from typing import Dict
from tempfile import TemporaryDirectory
from ipykernel_helper import *

from fastcore.utils import *
from fastcore.meta import delegates
from ghapi.all import *
from fastlite import *
from fastcore.xtras import asdict

In [None]:
#| export
# nbdev convention: names listed in `_all_` are added to the generated module's `__all__`,
# so the re-exported `asdict` (imported from fastcore) is also exported by this module.
_all_ = ["asdict"]

In [None]:
#| export
def get_db(ns:dict=None):
    "Open the dialog database; if `ns` is given, also inject its public table dataclasses into `ns`."
    # When /.dockerenv exists the app root is /app, otherwise the current directory
    app_path = Path('/app') if Path('/.dockerenv').exists() else Path('.')
    # NOTE(review): any non-empty IN_SOLVEIT value (even '0') selects the solveit database
    if os.environ.get('IN_SOLVEIT', False): dataparent,nm = app_path, 'data.db'
    else: dataparent,nm = Path('..'),'dev_data.db'
    db = database(dataparent/'data'/nm)
    # Computed even when `ns` is None: presumably `all_dcs` also registers the dataclasses
    # on the tables (rows are accessed as objects elsewhere) -- confirm before making it lazy.
    dcs = [o for o in all_dcs(db) if not o.__name__.startswith('_')]
    # `is not None` so that an empty dict passed as `ns` is still populated
    if ns is not None:
        for o in dcs: ns[o.__name__]=o
    return db

In [None]:
# Open the database and inject the table dataclasses into this notebook's globals
db = get_db(globals())
# Grab a dialog row to use as sample context for the cells below
dlg = db.t.dialog.fetchone()
dlg

In [None]:
#| export
def find_var(var:str):
    "Search for var in all frames of the call stack, innermost caller first; raises ValueError if not found."
    # Start at the caller: this function's own locals (`var`, `frame`) must not shadow the searched name
    frame = inspect.currentframe().f_back
    try:
        while frame is not None:
            # Locals shadow globals, as in normal Python name resolution. Presence (not truthiness)
            # decides, so variables holding 0, '' or None are found too.
            for scope in (frame.f_locals, frame.f_globals):
                if var in scope: return scope[var]
            frame = frame.f_back
    finally: del frame  # drop the frame reference to avoid a reference cycle (see `inspect` docs)
    raise ValueError(f"Could not find {var} in any scope")

In [None]:
# `find_var` locates a name defined in the calling scope
a = 1
find_var('a')

In [None]:
#| export
def find_dialog_id():
    "Return the current dialog's id, located by walking the call stack for a `__dialog_id` variable."
    dialog_id = find_var('__dialog_id')
    return dialog_id

In [None]:
# Define the variable `find_dialog_id` searches for, using the sample dialog's id
__dialog_id = dlg.id

In [None]:
# Should display the id assigned to `__dialog_id` above
find_dialog_id()

In [None]:
#| export
def find_msgs(
    pattern: str, # Text to search for
    limit:int=10 # Limit number of returned items
):
    "Find messages in the current dialog whose content contains `pattern` literally, ordered by `mid`, as dicts."
    did = find_dialog_id()
    db = get_db()
    # Escape LIKE wildcards so `%` and `_` in `pattern` match literally (`_` is very common in code).
    # Note: SQLite's LIKE is case-insensitive for ASCII by default.
    esc = pattern.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
    res = db.t.message("did=? AND content LIKE ? ESCAPE '\\'", [did, f'%{esc}%'], order_by='mid', limit=limit)
    return [asdict(o) for o in res]

In [None]:
# Search the current dialog and show the first match
found = find_msgs('to the')
found[0]

In [None]:
#| export
def find_msg_id():
    "Get the current message id by searching the call stack for __msg_id."
    return find_var('__msg_id')

In [None]:
# `find_msgs` returns plain dicts (via `asdict`), so use key access rather than `.sid`
__msg_id = found[0]['sid']

In [None]:
# Should display the id assigned to `__msg_id` above
find_msg_id()

In [None]:
#| export
def read_msg_ids():
    "Return the `sid` of every message in the current dialog, ordered by `mid`."
    dialog_id = find_dialog_id()
    messages = get_db().t.message
    rows = messages('did=?', [dialog_id], select='sid', order_by='mid')
    return [row.sid for row in rows]

In [None]:
#| export
def msg_idx():
    "Return `(ids, position)`: all message ids in the current dialog and where the current message sits among them."
    all_ids = read_msg_ids()
    position = all_ids.index(find_msg_id())
    return all_ids, position

In [None]:
# Position of the current message within the dialog
ids,idx = msg_idx()
idx

In [None]:
#| export
def read_msg(n:int=-1,     # Message index (if relative, +ve is downwards)
             relative:bool=True  # Is `n` relative to current message (True) or absolute (False)?
    ):
    "Get the message at index `n` in the current dialog, or None if that index is out of range."
    if relative:
        ids,idx = msg_idx()
        idx = idx+n
        # Relative offsets must not wrap around via Python's negative indexing
        if not 0<=idx<len(ids): return None
    else:
        # Absolute index: the current message's position isn't needed; negative `n` counts from the end
        ids,idx = read_msg_ids(),n
        if not -len(ids)<=idx<len(ids): return None
    db = get_db()
    return db.t.message.fetchone('sid=?', [ids[idx]])

In [None]:
# Relative mode (the default): the message immediately above the current one
read_msg(-1)

In [None]:
# Absolute mode: n=-1 indexes from the end, i.e. the last message in the dialog
read_msg(-1, relative=False)

In [None]:
#| export
# Signature-only stub: `add_msg` is decorated with `@delegates(_msg)`, which exposes these
# optional Message fields as keyword parameters in `add_msg`'s signature.
# Presumably never called directly (its body is `...`).
def _msg(
    input_tokens: int | None = 0,
    output_tokens: int | None = 0,
    time_run: str | None = '',
    is_exported: int | None = 0,
    skipped: int | None = 0,
    did: int | None = None,
    i_collapsed: int | None = 0,
    o_collapsed: int | None = 0,
    header_collapsed: int | None = 0,
    pinned: int | None = 0
): ...

In [None]:
#| export
@delegates(_msg)
def add_msg(
    content:str, # content of the message (i.e the message prompt, code, or note text)
    msg_type: str='note', # message type, can be 'code', 'note', or 'prompt'
    output:str='', # for prompts/code, initial output
    placement:str='add_after', # can be 'add_after', 'add_before', 'update', 'at_start', 'at_end'
    sid:str=None, # id of message that placement is relative to (if None, uses current message)
    **kwargs # additional Message fields such as skipped, i/o_collapsed, etc, passed through to the server
):
    "Add/update a message to the queue to show after code execution completes."
    assert msg_type in ('note', 'code', 'prompt'), "msg_type must be 'code', 'note', or 'prompt'."
    # Compare directly: `('note')` is a plain string, so `in` would be a substring test, not tuple membership
    assert msg_type != 'note' or not output, "'note' messages cannot have an output."
    run_cmd('add_msg', content=content, msg_type=msg_type, output=output, placement=placement, sid=sid, **kwargs)

In [None]:
#| export
@delegates(add_msg)
def update_msg(
    msg:Optional[Dict]=None, # Dictionary of field keys/values to update
    sid:str=None, # id of message to update (used when `msg` has no 'sid' key)
    content:str=None, # content of the message (i.e the message prompt, code, or note text)
    **kwargs):
    "Update an existing message. Provide either `msg` OR field key/values to update. Use `content` param to update contents."
    # `is not None` so content can be cleared by passing ''
    if content is not None: kwargs['content']=content
    assert bool(msg)^bool(kwargs), "Provide *either* msg, or kwargs, not both"
    # The target comes from `msg['sid']` if present, else the `sid` argument
    if msg and 'sid' in msg: target_id = msg['sid']
    elif sid: target_id = sid
    else: raise TypeError("update_msg needs either a dict message or `sid=...`")
    # Start from the stored message so fields not being updated keep their current values
    old = asdict(get_db().t.message[target_id])
    kw = old | (msg or {}) | kwargs
    kw.pop('did', None)
    return add_msg(placement='update', **kw)

In [None]:
#| export
def add_html(
    html:str, # HTML to add to the DOM
):
    "Inject `html` into the current web page; HTMX attributes in the markup are supported as well."
    command = 'add_ft'
    run_cmd(command, html=html)

In [None]:
#| export
def load_gist(gist_id:str):
    "Fetch a gist from GitHub; `gist_id` may be a bare id, or `user/id` (anything before the user is ignored)."
    api = GhApi()
    parts = gist_id.split('/')
    user = parts[-2] if len(parts) > 1 else None
    return api.gists.get(parts[-1], user=user)

In [None]:
# Sample gist, given in `user/id` form
gistid = 'jph00/e7cfd4ded593e8ef6217e78a0131960c'
gist = load_gist(gistid)
gist.html_url

In [None]:
#| export
def gist_file(gist_id:str):
    "Return the first file entry of the gist identified by `gist_id`."
    files = load_gist(gist_id).files
    return first(files.values())

In [None]:
# Show the source of the sample gist's first file
gfile = gist_file(gistid)
print(gfile.content)

In [None]:
#| export
def import_string(
    code:str, # Code to import as a module
    name:str  # Name of module to create
):
    "Create a module called `name` from source `code`, register it in `sys.modules`, execute it, and return it."
    import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
    with TemporaryDirectory() as tmpdir:
        path = Path(tmpdir) / f"{name}.py"
        # Python decodes source files as UTF-8, so don't write with the locale's default encoding
        path.write_text(code, encoding='utf-8')
        # linecache.cache storage allows inspect.getsource() after tmpdir lifetime ends
        # (an mtime of None makes linecache.checkcache leave the entry alone)
        linecache.cache[str(path)] = (len(code), None, code.splitlines(keepends=True), str(path))
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        # Register before executing so the module can see itself in sys.modules during import
        sys.modules[name] = module
        try: spec.loader.exec_module(module)
        except BaseException:
            # Like the regular import system, don't leave a half-initialised module registered
            sys.modules.pop(name, None)
            raise
        return module

In [None]:
#| export
def import_gist(
    gist_id:str, # user/id or just id of gist to import as a module
    mod_name:str=None, # module name to create (taken from gist filename if not passed)
    add_global:bool=True # add module to caller's globals?
):
    "Import the first file of a gist as a module (via `import_string`, which uses a short-lived temp file)."
    fil = gist_file(gist_id)
    # Default module name is the gist file's name without its extension
    mod_name = mod_name or Path(fil['filename']).stem
    module = import_string(fil['content'], mod_name)
    # Bind the module in the *caller's* module globals, so it is usable like a normal import
    if add_global: inspect.currentframe().f_back.f_globals[mod_name] = module
    return module

In [None]:
# `add_global` binds the module in this notebook's globals under the gist file's stem
# (presumably the sample gist's file is `importtest.py` -- confirm)
import_gist(gistid)
importtest.testfoo

In [None]:
#| export
def _split_sections(content:str):
    "Map each `#%% name` marker in `content` to the text between it and the next marker."
    sections,name,buf = {},None,[]
    for line in content.splitlines():
        if line.startswith('#%%'):
            # Store the section just finished (text before the first marker, or under an empty name, is dropped)
            if name: sections[name] = '\n'.join(buf)
            name,buf = line.replace('#%%', '').strip(),[]
        else: buf.append(line)
    if name: sections[name] = '\n'.join(buf)
    return sections

def _public_func_names(code:str):
    "Names of the top-level functions (sync or async) in `code` that don't start with `_`."
    import ast
    # Only the module body: nested helpers and class methods are not callable tools
    return [node.name for node in ast.parse(code).body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and not node.name.startswith('_')]

def import_tools_gist(gist_id: str):
    """Import tools and prompt from a gist into the current dialog.

    The gist should contain:
    - A section marked with #%% imports
    - A section marked with #%% tools
    - A section marked with #%% prompt

    Text before the first `#%%` marker is ignored. Only top-level public functions
    in the tools section are listed in the final tools prompt.
    """
    sections = _split_sections(gist_file(gist_id)['content'])

    # Prompt text as a pinned note, placed at the start of the dialog
    if 'prompt' in sections:
        add_msg(content=sections['prompt'].strip(), msg_type='note', placement='at_start', pinned=1)

    # Imports as a code message, also placed at the start
    if 'imports' in sections:
        add_msg(content=sections['imports'].strip(), msg_type='code', placement='at_start')

    # Tools as a code message at the end, followed by a prompt listing the tool functions
    if 'tools' in sections:
        tools_code = sections['tools'].strip()
        add_msg(content=tools_code, msg_type='code', placement='at_end')
        func_names = _public_func_names(tools_code)
        add_msg(content=f"&`[{', '.join(func_names)}]`", msg_type='prompt', placement='at_end')

## export -

In [None]:
#|hide
# Regenerate the library module(s) from this notebook's `#| export` cells
from nbdev import nbdev_export
nbdev_export()