
# Memory-based language modeling

Looking for a relatively eco-friendly LLM? Memory-based language models (MBLMs) run entirely on CPUs; no GPUs or TPUs are required. Training an MBLM is costly in RAM, but not in time or computing resources. Running an MBLM in autoregressive, GPT-style mode also takes RAM, but it too runs on CPUs and is reasonably fast, depending on which approximation of k-nearest neighbor classification is selected.
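To make the idea of k-nearest neighbor next-token prediction concrete, here is a minimal pure-Python sketch. It is a toy illustration, not TiMBL's implementation: the names `overlap_distance` and `knn_next_token` and the four-instance memory are assumptions made for this example, and the toy ignores feature weighting and the tree-based approximations that make real MBLMs fast.

```python
from collections import Counter

def overlap_distance(a, b):
    """Count the positions at which two equal-length contexts differ."""
    return sum(x != y for x, y in zip(a, b))

def knn_next_token(memory, context, k=1):
    """Return the majority next token among the k nearest stored contexts."""
    ranked = sorted(memory, key=lambda item: overlap_distance(item[0], context))
    votes = Counter(label for _, label in ranked[:k])
    return votes.most_common(1)[0][0]

# Hypothetical memory: (left context, next token) pairs from one sentence
memory = [
    (("_", "the", "cat"), "sat"),
    (("the", "cat", "sat"), "on"),
    (("cat", "sat", "on"), "the"),
    (("sat", "on", "the"), "mat"),
]

# The query differs from the second stored context in one position only
print(knn_next_token(memory, ("a", "cat", "sat")))  # on
```

"Training" here amounts to storing instances, which is why memory rather than compute is the main cost.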

In [9]:
from transformers import AutoTokenizer
import re
import time
import argparse
import sys
import ast

Installing `python3-timbl`, Python bindings for the TiMBL engine.

In [10]:
!pip install python3-timbl

import timbl



Downloading an example MBLM (approx. 457 MB) with Git LFS. This may take a while.

In [11]:
!git clone https://github.com/antalvdb/mblm
%cd mblm
!git lfs pull -I chatbot-instruction-prompts_tok.l16r0.igtree.ibase
%cd

Cloning into 'mblm'...
remote: Enumerating objects: 93, done.
remote: Counting objects: 100% (93/93), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 93 (delta 48), reused 26 (delta 10), pack-reused 0 (from 0)
Receiving objects: 100% (93/93), 54.24 KiB | 1009.00 KiB/s, done.
Resolving deltas: 100% (48/48), done.
/root/mblm
/root


Setting a global verbosity level and defining a function for logging.

In [12]:
# Global verbosity level
VERBOSITY = 1

In [13]:
def log(message, level=1):
    """Logs a message if the verbosity level is sufficient."""
    if VERBOSITY >= level:
        print(message)

Defining functions for padding the prompt with underscores and for generating text.

In [14]:
def pad_prompt(words, max_len=16):
    """Pad or trim the list of words to make it exactly `max_len` words."""
    if words is None:
        words = []  # Ensure words is a list
    if len(words) < max_len:
        words = ['_'] * (max_len - len(words)) + words
    else:
        words = words[-max_len:]
    return words
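A quick check of the padding behavior may help here. The sketch below repeats the same logic so that it runs on its own; the sample inputs are made up for illustration.

```python
def pad_prompt(words, max_len=16):
    """Same logic as the cell above, repeated so this sketch is self-contained."""
    if words is None:
        words = []
    if len(words) < max_len:
        return ['_'] * (max_len - len(words)) + words
    return words[-max_len:]

short = pad_prompt(['Hello', 'world'], max_len=4)       # left-padded
trimmed = pad_prompt(['a', 'b', 'c', 'd', 'e'], max_len=4)  # keeps the last 4
empty = pad_prompt(None, max_len=3)                     # padding only

print(short)    # ['_', '_', 'Hello', 'world']
print(trimmed)  # ['b', 'c', 'd', 'e']
print(empty)    # ['_', '_', '_']
```

Padding goes on the left and trimming keeps the rightmost tokens, so the most recent context always sits in the final feature positions.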

In [15]:
def generate_text_from_api(classifier, initial_prompt, max_words=200):
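    """Generate up to `max_words` tokens that continue `initial_prompt`.

    Relies on the global `tokenizer` being initialized before the call.
    """
    # Tokenize the initial prompt into subword tokens
    initial_tokens = tokenizer.tokenize(initial_prompt)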

    if initial_tokens is None:
        log("Tokenization failed; 'initial_tokens' is None.", level=1)
        initial_tokens = []

    # Build one (context, next-token) instance per prompt position
    padded_instances = []

    for i in range(len(initial_tokens)):
        # Pad or trim the tokens preceding position i to 16 tokens
        instance = pad_prompt(initial_tokens[:i], max_len=16)
        padded_instances.append((instance, initial_tokens[i]))

    # Add instances to memory
    for input_instance, next_token in padded_instances:
        log(f"memorized: {input_instance} {next_token}", level=2)
        classifier.append(input_instance, next_token)

    # Use the final part of the prompt for further generation
    prompt_words = pad_prompt(initial_tokens)

    generated_tokens = prompt_words[:]  # Store the full generated text

    try:
        # Generate max_words tokens, one per iteration
        for _ in range(max_words):

            classlabel, distribution, distance = classifier.classify(prompt_words)

            # Add instance to instance base
            classifier.append(prompt_words, classlabel)

            log(f"Prompt words: {prompt_words}", level=2)
            log(f"Classlabel: {classlabel}", level=2)
            log(f"Distribution: {distribution}", level=3)
            log(f"Distance: {distance}", level=3)

            generated_tokens.append(classlabel)

            # Shift prompt words and add the new word
            prompt_words = prompt_words[1:] + [classlabel]

            # Optional: uncomment to stop at the first generated period
            # if classlabel == ".":
            #     break

        # Drop the leading padding tokens added by pad_prompt, then detokenize.
        # Slicing by count keeps any genuine underscores in the generated text.
        num_padding = max(0, 16 - len(initial_tokens))
        generated_text = tokenizer.convert_tokens_to_string(generated_tokens[num_padding:]).strip()

        # Print the final generated text
        log(f"Generated text: {generated_text}", level=1)

    except Exception as e:
        log(f"Error: {e}", level=1)
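
The two steps in the function above, turning the prompt into (context, next-token) instances and then predicting with a sliding context window, can be sketched without TiMBL. Everything here is an assumption made for illustration: `make_instances`, `generate`, and `LookupClassifier` are hypothetical names, and the exact-match lookup is only a stand-in for a real memory-based classifier.

```python
def make_instances(tokens, width=4, pad="_"):
    """Turn a token list into (left-context, next-token) pairs."""
    instances = []
    for i, target in enumerate(tokens):
        context = tokens[max(0, i - width):i]
        context = [pad] * (width - len(context)) + context
        instances.append((context, target))
    return instances

class LookupClassifier:
    """Hypothetical stand-in for a TiMBL classifier: exact-match lookup only."""
    def __init__(self, instances, fallback="."):
        self.table = {tuple(ctx): label for ctx, label in instances}
        self.fallback = fallback

    def classify(self, context):
        return self.table.get(tuple(context), self.fallback)

def generate(classifier, context, max_tokens=5):
    """Predict a token, append it, and shift the context window by one."""
    context = list(context)
    output = []
    for _ in range(max_tokens):
        token = classifier.classify(context)
        output.append(token)
        context = context[1:] + [token]
    return output

tokens = "the cat sat on the mat".split()
instances = make_instances(tokens, width=2)
clf = LookupClassifier(instances)
print(generate(clf, ["_", "_"], max_tokens=6))
# ['the', 'cat', 'sat', 'on', 'the', 'mat']
```

The context width stays fixed throughout, which is why each prediction drops the oldest token before appending the newest one.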


In [None]:
if __name__ == "__main__":

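    # Simulate command-line arguments for the notebook environment.
    # NOTE: --classifier is a file prefix that must match the location and name
    # of the instance base downloaded above; adjust it if your files differ.
    sys.argv = [
        'script_name',
        '--classifier', '/content/mblm/chatbot-instruction-prompts-100k_tok.l16r0',
        '--tokenizer', 'bert-base-cased',
        '--timbl_args', '-a4 +D',
        '--verbosity', '3',
    ]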

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True, help="Path to the Timbl classifier file")
    parser.add_argument("--tokenizer", type=str, required=True, help="Name of the Hugging Face tokenizer")
    parser.add_argument("--timbl_args", type=str, required=True, help="Timbl arguments as a single string (e.g., '-a4 +D')")
    parser.add_argument("--verbosity", type=int, default=0, help="Verbosity level (0: silent, 1: basic, 2: detailed, 3: debug)")
    args = parser.parse_args()

    # Set global verbosity level
    VERBOSITY = args.verbosity

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    # Initialize the classifier
    log("Loading TiMBL Classifier...", level=1)
    classifier = timbl.TimblClassifier(args.classifier, args.timbl_args)
    classifier.load()

    # Loop to continuously ask for input and classify
    while True:
        # Take input from the user
        user_input = input("Please enter prompt (or type 'exit' to quit): ")

        # Check if the user wants to exit
        if user_input.lower() == 'exit':
            log("Exiting.", level=1)
            break

        # Pass the input to the classifier function
        generate_text_from_api(classifier, user_input)
