In [None]:
#|default_exp rag

# ShellSage Retrieval Augmented Generation

## Imports

In [None]:
#| export
from chonkie import SentenceChunker
from fastcore.all import *
from lancedb import connect
from lancedb.pydantic import LanceModel, Vector
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from subprocess import check_output as co

import os, re, subprocess
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [None]:
import random

## Database setup

In [None]:
#| export
# LanceDB files live under ~/.cache/shell_sage/db; create the directory tree on first use
db_path = Path.home().joinpath('.cache', 'shell_sage', 'db')
db_path.mkdir(exist_ok=True, parents=True)

In [None]:
#| export
# Shared sentence chunker used when indexing man pages: chunk size 2,048 with an overlap
# of 256 (presumably counted in gpt2 tokens, per `tokenizer` — confirm against chonkie docs)
chunker = SentenceChunker(tokenizer="gpt2", chunk_size=2_048,
                          chunk_overlap=256, min_sentences_per_chunk=1)

In [None]:
#| export
# Wrap a single static-embedding module in a SentenceTransformer pinned to the CPU
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-retrieval-32M")
model = SentenceTransformer(modules=[static_embedding], device="cpu")
# Width of one embedding, read off a probe encoding; it sizes the `Vector` column of the table schema
ndim = model.encode(["Example sentence"]).shape[1]

In [None]:
# sanity check: embedding width measured from the probe encoding
ndim

512

In [None]:
#| export
class EmbeddingTable(LanceModel):
    "Row schema for the `man_pages` table: one embedded chunk of a man page"
    content: str  # text of the chunk
    package_name: str  # name of the command whose man page the chunk was taken from
    embedding: Vector(ndim)  # embedding of `content`; vector of width `ndim`

In [None]:
#| export
db = connect(db_path)
# Reuse the `man_pages` table when it already exists; otherwise create it from `EmbeddingTable`.
# The "table not found" error type looks version-dependent: ValueError in the release this was
# written against, while older LanceDB releases reportedly raise FileNotFoundError (TODO confirm
# against the pinned version). Both are treated as "table is missing".
try: tbl = db.open_table("man_pages")
except (ValueError, FileNotFoundError):
    tbl = db.create_table("man_pages", schema=EmbeddingTable, mode="create")
    tbl.create_fts_index("content") # for hybrid search

In [None]:
# list what the database directory contains after the table has been opened or created
db_path.ls()

(#1) [Path('/Users/nathan/.cache/shell_sage/db/man_pages.lance')]

## Man page indexing

In [None]:
#| export
def _get_page(cmd):
    "Return `(cmd, page)` with the stripped output of `man cmd`, or `(cmd, None)` if the page can't be read"
    try:
        page = co(['man', cmd], text=True, stderr=subprocess.DEVNULL)
    except Exception:
        # Best effort: a non-zero exit (no such page), a missing `man` binary or undecodable
        # output all mean "no page". Catching `Exception` rather than using a bare `except:`
        # lets KeyboardInterrupt and SystemExit propagate, so a long run can still be cancelled.
        return cmd, None
    return cmd, page.strip()

def clean(text: str) -> str:
    "Strip backspace overstrikes and ANSI escape sequences from man page text and squeeze repeated newlines"
    overstrike = re.compile(r'.\x08')
    # A single pass can expose a new char+backspace pair (two chars followed by two
    # backspaces, say), so keep going until a pass removes nothing
    removed = 1
    while removed:
        text, removed = overstrike.subn('', text)

    without_ansi = re.sub(r'(?:\x1B[@-Z\\-_]|\x1B\[?[\d;]*[A-PRZcf-nq-uy=><])', '', text)

    # Runs of newlines collapse to one; leading/trailing whitespace is trimmed
    return re.sub(r'\n+', '\n', without_ansi).strip()

def _man_pages(lim=None):
    "Return `(cmds, pages)`: command names listed by `man -k .` (at most `lim`) and their cleaned man pages, as two tuples"
    lines = L(co(['man', '-k', '.'], text=True).strip().splitlines())
    # keep entries carrying a parenthesised tag that starts with "1" (presumably man section 1)
    lines = lines.filter(lambda s: re.search(r'\(1[^\)]*\)', s))
    # the command name is whatever precedes the first "("; drop blanks and duplicates, then cap at `lim`
    cmds = (lines.map(lambda s: s.split("(")[0].strip())
                 .filter(lambda s: s).unique()[:lim])

    # fetch the pages in parallel and drop commands whose page is missing or empty
    pages = parallel(_get_page, cmds).filter(lambda x: x[1])
    # `zip(*[])` yields nothing, which would make `cmds, pages = _man_pages(...)` fail with
    # "not enough values to unpack"; hand back two empty tuples so callers can always unpack
    if len(pages) == 0: return (), ()
    return tuple(zip(*pages.map(lambda x: (x[0], clean(x[1])))))


In [None]:
# cap the walkthrough at 128 commands (`lim=128`)
cmds, pages = _man_pages(128)

In [None]:
# Pick a random page to eyeball. `randrange` excludes the upper bound, whereas
# `randint(0, len(pages))` can return `len(pages)` itself and raise IndexError.
idx = random.randrange(len(pages))
cmd, page = cmds[idx], pages[idx]
print(cmd)
print(page[:128])

git-credential-cache--daemon
GIT-CREDENTIAL-CAC(1)             Git Manual             GIT-CREDENTIAL-CAC(1)
NAME
       git-credential-cache--daemon - Tempor


In [None]:
#| export
def index_man_pages(cmds, pages):
    "Chunk every page, embed the chunks and append one `EmbeddingTable` row per chunk to `tbl`"
    chunked_pages = chunker.chunk_batch(pages)
    for name, page_chunks in zip(cmds, chunked_pages):
        # embed all chunks of one page in a single `encode` call
        vectors = model.encode([piece.text for piece in page_chunks])
        rows = []
        for piece, vec in zip(page_chunks, vectors):
            rows.append(EmbeddingTable(content=piece.text, package_name=name, embedding=vec))
        tbl.add(rows)

In [None]:
# chunk, embed and store the pages fetched above
index_man_pages(cmds, pages)

🦛 choooooooooooooooooooonk 100% • 127/127 docs chunked [00:01<00:00, 106.70doc/s] 🌱


In [None]:
#| export
def search(q: str, limit: int=2, threshold: float=0.5):
    "Return the (at most `limit`) nearest indexed chunks to `q` whose cosine distance is below `1 - threshold`"
    query_vec = model.encode([q])
    hits = (tbl.search(query_vec)
               .metric("cosine")
               .limit(limit)
               .to_pandas()
               .rename(columns={"_distance": "cosine_distance"}))
    # the cut-off is applied after `limit`, so fewer than `limit` rows may come back
    close_enough = hits.cosine_distance < 1 - threshold
    return hits[close_enough]

In [None]:
# try a natural-language question against the indexed pages
df = search("How can I change my current branch?")
df.head()

Unnamed: 0,content,package_name,embedding,cosine_distance
0,Both the index and working tree are...,git-switch,"[-0.3973836, 1.8147281, 0.0032745493, 1.281918...",0.43894
1,GIT-SWITCH(1) Git Manual ...,git-switch,"[0.0898071, 1.866398, 0.020456946, 1.2332366, ...",0.447789


## CLI

In [None]:
#| export
@call_parse
def main():
    "Index man pages for RAG"
    # no `lim` here: every command `_man_pages` returns gets indexed
    names, texts = _man_pages()
    index_man_pages(names, texts)