In [1]:
!pip install pdfminer pdfminer.six sentence-transformers text-generation

Collecting pdfminer
  Downloading pdfminer-20191125.tar.gz (4.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pdfminer.six
  Downloading pdfminer.six-20221105-py3-none-any.whl (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m103.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting text-generation
  Downloading text_generation-0.6.0-py3-none-any.whl (10 kB)
Collecting pycryptodome (from pdfminer)
  Downloading pycryptodome-3.18.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [11]:
import argparse

from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer, CrossEncoder, util

from text_generation import Client

PREPROMPT = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n"
PROMPT = """"Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to
make up an answer. Don't make up new terms which are not available in the context.
{context}"""

END_7B = "\n<|prompter|>{query}<|endoftext|><|assistant|>"
END_40B = "\nUser: {query}\nFalcon:"

PARAMETERS = {
    "temperature": 0.9,
    "top_p": 0.95,
    "repetition_penalty": 1.2,
    "top_k": 50,
    "truncate": 1000,
    "max_new_tokens": 1024,
    "seed": 42,
    "stop_sequences": ["<|endoftext|>", "</s>"],
}
CLIENT_7B = Client("http://")  # Fill this part
CLIENT_40B = Client("https://")  # Fill this part


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fname", type=str, required=True)
    parser.add_argument("--top-k", type=int, default=32)
    parser.add_argument("--window-size", type=int, default=128)
    parser.add_argument("--step-size", type=int, default=100)
    return parser.parse_args()


def embed(fname, window_size, step_size):
    text = extract_text(fname)
    text = " ".join(text.split())
    text_tokens = text.split()

    sentences = []
    for i in range(0, len(text_tokens), step_size):
        window = text_tokens[i : i + window_size]
        if len(window) < window_size:
            break
        sentences.append(window)

    paragraphs = [" ".join(s) for s in sentences]
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    model.max_seq_length = 512
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    embeddings = model.encode(
        paragraphs,
        show_progress_bar=True,
        convert_to_tensor=True,
    )
    return model, cross_encoder, embeddings, paragraphs


def search(query, model, cross_encoder, embeddings, paragraphs, top_k):
    query_embeddings = model.encode(query, convert_to_tensor=True)
    query_embeddings = query_embeddings.cuda()
    hits = util.semantic_search(
        query_embeddings,
        embeddings,
        top_k=top_k,
    )[0]

    cross_input = [[query, paragraphs[hit["corpus_id"]]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_input)

    for idx in range(len(cross_scores)):
        hits[idx]["cross_score"] = cross_scores[idx]

    results = []
    hits = sorted(hits, key=lambda x: x["cross_score"], reverse=True)
    for hit in hits[:5]:
        results.append(paragraphs[hit["corpus_id"]].replace("\n", " "))
    return results


if __name__ == "__main__":
    args = parse_args()
    model, cross_encoder, embeddings, paragraphs = embed(
        args.fname,
        args.window_size,
        args.step_size,
    )
    print(embeddings.shape)
    while True:
        print("\n")
        query = input("Enter query: ")
        results = search(
            query,
            model,
            cross_encoder,
            embeddings,
            paragraphs,
            top_k=args.top_k,
        )

        query_7b = PREPROMPT + PROMPT.format(context="\n".join(results))
        query_7b += END_7B.format(query=query)

        query_40b = PREPROMPT + PROMPT.format(context="\n".join(results))
        query_40b += END_40B.format(query=query)

        text = ""
        for response in CLIENT_7B.generate_stream(query_7b, **PARAMETERS):
            if not response.token.special:
                text += response.token.text

        print("\n***7b response***")
        print(text)

        text = ""
        for response in CLIENT_40B.generate_stream(query_40b, **PARAMETERS):
            if not response.token.special:
                text += response.token.text

        print("\n***40b response***")
        print(text)

usage: ipykernel_launcher.py [-h] --fname FNAME [--top-k TOP_K]
                             [--window-size WINDOW_SIZE]
                             [--step-size STEP_SIZE]
ipykernel_launcher.py: error: the following arguments are required: --fname


SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [2]:
!python chat.py --fname '/content/1706.03762.pdf'

Downloading (…)a8e1d/.gitattributes: 100% 1.18k/1.18k [00:00<00:00, 7.24MB/s]
Downloading (…)_Pooling/config.json: 100% 190/190 [00:00<00:00, 1.22MB/s]
Downloading (…)b20bca8e1d/README.md: 100% 10.6k/10.6k [00:00<00:00, 47.3MB/s]
Downloading (…)0bca8e1d/config.json: 100% 571/571 [00:00<00:00, 3.80MB/s]
Downloading (…)ce_transformers.json: 100% 116/116 [00:00<00:00, 767kB/s]
Downloading (…)e1d/data_config.json: 100% 39.3k/39.3k [00:00<00:00, 987kB/s]
Downloading pytorch_model.bin: 100% 438M/438M [00:01<00:00, 273MB/s]
Downloading (…)nce_bert_config.json: 100% 53.0/53.0 [00:00<00:00, 359kB/s]
Downloading (…)cial_tokens_map.json: 100% 239/239 [00:00<00:00, 1.39MB/s]
Downloading (…)a8e1d/tokenizer.json: 100% 466k/466k [00:00<00:00, 944kB/s]
Downloading (…)okenizer_config.json: 100% 363/363 [00:00<00:00, 2.21MB/s]
Downloading (…)8e1d/train_script.py: 100% 13.1k/13.1k [00:00<00:00, 31.0MB/s]
Downloading (…)b20bca8e1d/vocab.txt: 100% 232k/232k [00:00<00:00, 992kB/s]
Downloading (…)bca8e1d/mod