<a href="https://colab.research.google.com/github/antalvdb/olifant/blob/main/timbl_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Memory-based language modeling with Olifant

Looking for an LLM that is relatively eco-friendly? Memory-based language models (MBLMs) run on CPUs; no GPUs or TPUs are required. Training an MBLM is costly in terms of RAM, but not in terms of time or computing resources. Running an MBLM in autoregressive GPT-style mode also costs RAM, but it still runs on CPUs and is reasonably fast, depending on the selected approximation of k-nearest neighbor classification.


In this notebook we work with the `olifant` package, which installs all necessary components, such as the TiMBL engine. We also make use of the Hugging Face transformers library.

In [None]:
!pip install olifant

import timbl
from transformers import AutoTokenizer, AutoModel
import re
import time
import argparse
import sys
import ast


Connect to Google Drive if you have files there that you want to use in this notebook.

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Or download a sample MBLM:



In [None]:
!wget https://antalvandenbosch.nl/mblm/edufineweb_train_000001.tok.l4r0.igtree.ibase
!wget https://antalvandenbosch.nl/mblm/edufineweb_train_000001.tok.l4r0.igtree.ibase.wgt

We define a global verbosity level and introduce a simple logging function.

In [None]:
# Global verbosity level
VERBOSITY = 3

In [None]:
def log(message, level=1):
    """Logs a message if the verbosity level is sufficient."""
    if VERBOSITY >= level:
        print(message)
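
To make the level semantics concrete, here is a minimal sketch that repeats `log` so it runs on its own and captures what is printed at a verbosity of 1. The sample messages and the use of `redirect_stdout` are only for illustration and are not part of the notebook.

```python
import io
from contextlib import redirect_stdout

VERBOSITY = 1  # illustrative setting; the notebook starts at 3


def log(message, level=1):
    """Logs a message if the verbosity level is sufficient."""
    if VERBOSITY >= level:
        print(message)


# Capture printed output so we can inspect which messages pass the filter
buffer = io.StringIO()
with redirect_stdout(buffer):
    log("generated text", level=0)   # printed: 1 >= 0
    log("basic progress", level=1)   # printed: 1 >= 1
    log("debug details", level=3)    # suppressed: 1 < 3

shown = buffer.getvalue().splitlines()
print(shown)
```

A message is printed only when its `level` does not exceed `VERBOSITY`, so messages logged at level 0 appear at every verbosity setting.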

The following two functions do the essential work. `pad_prompt` pads or trims a list of tokens to a fixed context width (4 by default). `generate_text_from_prompt` calls the Olifant classifier iteratively to generate an output text autoregressively from a prompt given by the user. Text is generated up to a maximum number of tokens (200 by default) or until a period is generated.

In [None]:
def pad_prompt(words, max_len=4):
    """Pad or trim the list of words to make it exactly `max_len` words."""
    if words is None:
        words = []  # Ensure words is a list
    if len(words) < max_len:
        words = ['_'] * (max_len - len(words)) + words
    else:
        words = words[-max_len:]
    return words
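
As a quick illustration of the padding behaviour, the sketch below repeats `pad_prompt` so it runs on its own and applies it to a short, an exact-length, and an overlong token list. The sample tokens are made up for illustration.

```python
def pad_prompt(words, max_len=4):
    """Pad or trim the list of words to make it exactly `max_len` words."""
    if words is None:
        words = []
    if len(words) < max_len:
        # Left-pad with '_' so the most recent tokens stay at the right edge
        words = ['_'] * (max_len - len(words)) + words
    else:
        # Keep only the last `max_len` tokens
        words = words[-max_len:]
    return words


short = pad_prompt(['The', 'cat'])
exact = pad_prompt(['The', 'cat', 'sat', 'on'])
overlong = pad_prompt(['The', 'cat', 'sat', 'on', 'the', 'mat'])
empty = pad_prompt(None)

print(short)     # ['_', '_', 'The', 'cat']
print(exact)     # ['The', 'cat', 'sat', 'on']
print(overlong)  # ['sat', 'on', 'the', 'mat']
print(empty)     # ['_', '_', '_', '_']
```

Padding on the left keeps the most recent tokens closest to the position being predicted, which matches the four-token left context suggested by the `l4r0` part of the sample model's file name.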

In [None]:
def generate_text_from_prompt(classifier, initial_prompt, max_words=200):
    """Generate text autoregressively from a prompt; uses the global `tokenizer`."""
    # Tokenize the initial prompt
    initial_tokens = tokenizer.tokenize(initial_prompt)

    if initial_tokens is None:
        log("Tokenization failed; 'initial_tokens' is None.", level=1)
        initial_tokens = []

    # Build (context, next-token) instances from the prompt; these are only
    # needed by the optional memorization step below
    padded_instances = []

    for i in range(len(initial_tokens)):
        # Pad the tokens preceding position i to a 4-token context
        instance = pad_prompt(initial_tokens[:i], max_len=4)
        padded_instances.append((instance, initial_tokens[i]))

    # Optionally add the prompt instances to memory (disabled by default)
    # for input_instance, next_token in padded_instances:
    #     log(f"memorized: {input_instance} {next_token}", level=2)
    #     classifier.append(input_instance, next_token)

    # Use the final part of the prompt for further generation
    prompt_words = pad_prompt(initial_tokens)

    # generated_tokens = prompt_words[:]  # Start with the prompt
    generated_tokens = [] # Or, start empty

    try:
        # Loop until max words generated or a period token is found
        for _ in range(max_words):
            next_word = None

            classlabel, distribution, distance = classifier.classify(prompt_words)

            neighbors_output = classifier.bestNeighbours()

            log(f"Neighbors output: {neighbors_output}", level=3)


            # Add instance to instance base
            classifier.append(prompt_words, classlabel)

            log(f"Prompt words: {prompt_words}", level=2)
            log(f"Classlabel: {classlabel}", level=2)
            log(f"Distribution: {distribution}", level=3)
            log(f"Distance: {distance}", level=3)

            generated_tokens.append(classlabel)

            # Shift prompt words and add the new word
            prompt_words = prompt_words[1:] + [classlabel]

            # Stop if a period is generated
            if classlabel == ".":
                break

        # Detokenize the generated tokens
        generated_text = tokenizer.convert_tokens_to_string(generated_tokens)

        # Strip off original padding characters
        generated_text = generated_text.replace("_", "").strip()

        # Print the final generated text
        log(f"Generated text: {generated_text}", level=0)

    except Exception as e:
        log(f"Error: {e}", level=1)
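
The sliding-window loop at the heart of the function can be tried without loading a TiMBL model. The following is a minimal sketch under stated assumptions: `ToyClassifier` is a hypothetical stand-in that looks up the next token in a small hand-written table and only mimics the `(label, distribution, distance)` triple returned by `classify` above; `generate` and the table contents are likewise made up for illustration.

```python
class ToyClassifier:
    """Hypothetical stand-in for the TiMBL classifier (not the real API)."""

    def __init__(self, table):
        self.table = table  # maps a 4-token context tuple to the next token

    def classify(self, features):
        # Unknown contexts fall back to a period, which ends generation
        label = self.table.get(tuple(features), ".")
        return label, {label: 1.0}, 0.0


def generate(classifier, context, max_words=200):
    """Greedy autoregressive generation over a fixed-width context window."""
    generated = []
    for _ in range(max_words):
        label, _distribution, _distance = classifier.classify(context)
        generated.append(label)
        # Drop the oldest token and append the prediction
        context = context[1:] + [label]
        if label == ".":
            break
    return generated


table = {
    ('_', '_', 'the', 'cat'): 'sat',
    ('_', 'the', 'cat', 'sat'): 'on',
    ('the', 'cat', 'sat', 'on'): 'the',
    ('cat', 'sat', 'on', 'the'): 'mat',
}
tokens = generate(ToyClassifier(table), ['_', '_', 'the', 'cat'])
print(tokens)  # ['sat', 'on', 'the', 'mat', '.']
```

Each prediction becomes part of the next context, so the window always holds the four most recent tokens, whether they came from the prompt or from earlier predictions.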


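The following cell simulates a command-line call that takes four arguments:

1. `--classifier`: the path to the TiMBL classifier file, given here without the `.ibase` extension,
2. `--tokenizer`: the name of the Hugging Face tokenizer,
3. `--timbl_args`: the TiMBL arguments as a single string,
4. `--verbosity`: the verbosity level.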

When this cell is run, it might take a while to load the TiMBL classifier into memory (in Colab, check RAM in resources to see if the model fits). When prompted, just try some natural language text as input. When done, type 'exit'.

In [None]:
if __name__ == "__main__":

    # Simulate command-line arguments for notebook environment
    # sys.argv = ['olifant', '--classifier', '/content/mblm/chatbot-instruction-prompts_tok.l16r0.igtree', '--tokenizer', 'bert-base-cased', '--timbl_args', '-a1 +D', '--verbosity', '3']
    sys.argv = ['olifant', '--classifier', 'edufineweb_train_000001.tok.l4r0.igtree', '--tokenizer', 'gpt2', '--timbl_args', '-a1 +D', '--verbosity', '0']

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True, help="Path to the TiMBL classifier file")
    parser.add_argument("--tokenizer", type=str, required=True, help="Name of the Hugging Face tokenizer")
    parser.add_argument("--timbl_args", type=str, required=True, help="TiMBL arguments as a single string (e.g., '-a4 +D')")
    parser.add_argument("--verbosity", type=int, default=0, help="Verbosity level (0: generated text only, 1: basic, 2: detailed, 3: debug)")
    args = parser.parse_args()

    # Set global verbosity level
    VERBOSITY = args.verbosity

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    # Initialize the classifier
    log("Loading TiMBL Classifier...", level=1)
    classifier = timbl.TimblClassifier(args.classifier, args.timbl_args)
    classifier.load()

    # Loop to continuously ask for input and classify
    while True:
        # Take input from the user
        user_input = input("Please enter prompt (or type 'exit' to quit): ")

        # Check if the user wants to exit
        if user_input.lower() == 'exit':
            log("Exiting.", level=1)
            break

        # Pass the input to the classifier function
        generate_text_from_prompt(classifier, user_input)
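
In a notebook, one alternative to overwriting `sys.argv` is to pass the argument list directly to `parse_args`. The sketch below shows this under the assumption that the same four options are used; `build_parser` is a hypothetical helper name that does not appear in the notebook.

```python
import argparse


def build_parser():
    """Build the same four-option parser as in the cell above (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--timbl_args", type=str, required=True)
    parser.add_argument("--verbosity", type=int, default=0)
    return parser


# Passing an explicit list leaves sys.argv untouched
args = build_parser().parse_args([
    "--classifier", "edufineweb_train_000001.tok.l4r0.igtree",
    "--tokenizer", "gpt2",
    "--timbl_args", "-a1 +D",
    "--verbosity", "0",
])

print(args.classifier)
print(args.timbl_args)
print(args.verbosity)
```

Either approach yields the same `args` namespace; the explicit list simply avoids changing global interpreter state inside the notebook.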
