<a href="https://colab.research.google.com/github/antalvdb/mblm/blob/main/timbl_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import re
import time
import argparse
import sys
import ast

In [None]:
!pip install python3-timbl

import timbl

Collecting python3-timbl
  Downloading python3_timbl-2025.1.22-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (824 bytes)
Downloading python3_timbl-2025.1.22-cp311-cp311-manylinux_2_28_x86_64.whl (21.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.7/21.7 MB[0m [31m66.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: python3-timbl
Successfully installed python3-timbl-2025.1.22


In [None]:
# Mount Google Drive at /content/drive (Colab-only API). The classifier file
# used by the main cell of this notebook is read from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
# Global verbosity threshold read by log(): a message is printed when its
# level is <= VERBOSITY. The --verbosity help text in the main cell describes
# the scale as 0: silent, 1: basic, 2: detailed, 3: debug; the main cell
# overwrites this default with the parsed --verbosity value.
VERBOSITY = 1

In [None]:
# Load the pretrained tokenizer published as 'GroNLP/bert-base-dutch-cased'.
# The module-level `tokenizer` is used by generate_text_from_api() both to
# split the prompt into tokens and to join generated tokens back into text.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('GroNLP/bert-base-dutch-cased')

tokenizer_config.json:   0%|          | 0.00/254 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/242k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
def log(message, level=1):
    """Print `message` unless its `level` exceeds the global VERBOSITY."""
    if level > VERBOSITY:
        # Message is more detailed than the configured verbosity allows.
        return
    print(message)

In [None]:
def pad_prompt(words, max_len=16, pad_token='_'):
    """Return a new list of exactly `max_len` tokens built from `words`.

    Inputs shorter than `max_len` are left-padded with `pad_token`; longer
    inputs keep only their last `max_len` items, so the most recent context
    is what survives.

    Args:
        words: Sequence of tokens, or None (treated as an empty sequence).
        max_len: Required length of the result; must be >= 0.
        pad_token: Filler placed in front of short inputs (default '_').

    Returns:
        A new list of length `max_len`. The input sequence is not modified.

    Raises:
        ValueError: If `max_len` is negative.
    """
    if max_len < 0:
        raise ValueError(f"max_len must be non-negative, got {max_len}")

    # Copy into a list so tuples and other sequences can be concatenated with
    # the padding, and so the caller's object is never returned or mutated.
    words = [] if words is None else list(words)

    if max_len == 0:
        # words[-0:] equals words[0:] (the whole list), so the trimming slice
        # below cannot be used to produce an empty result.
        return []

    missing = max_len - len(words)
    if missing > 0:
        return [pad_token] * missing + words
    return words[-max_len:]

In [None]:
def generate_text_from_api(classifier, initial_prompt, max_words=200,
                           context_size=16, stop_token=None):
    """Memorize a prompt and generate a continuation with a Timbl classifier.

    The prompt is tokenized with the module-level `tokenizer`. Each prompt
    token is handed to `classifier.append` together with its left-padded
    preceding context. Then up to `max_words` tokens are produced one at a
    time: the current context window is classified, the (window, predicted
    label) pair is appended to the classifier, and the window slides forward
    by one token.

    Args:
        classifier: Object exposing `append(features, label)` and
            `classify(features)` returning `(label, distribution, distance)`.
        initial_prompt: Raw prompt text.
        max_words: Maximum number of tokens to generate.
        context_size: Number of context tokens per instance (default 16).
            NOTE(review): presumably this must equal the context width the
            classifier was trained with -- confirm before changing it.
        stop_token: Optional token that ends generation early once it has
            been predicted (the token itself is kept in the output). The
            default None generates exactly `max_words` tokens.

    Returns:
        The detokenized text (the last `context_size` prompt tokens followed
        by the generated tokens), or None if an exception was raised during
        generation; the error is logged at level 1 in that case.

    Raises:
        ValueError: If `context_size` is smaller than 1.
    """
    if context_size < 1:
        raise ValueError(f"context_size must be at least 1, got {context_size}")

    prompt_tokens = tokenizer.tokenize(initial_prompt)
    if prompt_tokens is None:
        log("Tokenization failed; 'initial_tokens' is None.", level=1)
        prompt_tokens = []

    # Memorize the prompt: one instance per token, whose features are the
    # padded tokens that precede it and whose label is the token itself.
    for position, target in enumerate(prompt_tokens):
        context = pad_prompt(prompt_tokens[:position], max_len=context_size)
        log(f"memorized: {context} {target}", level=2)
        classifier.append(context, target)

    # Sliding context window fed to the classifier (padded on the left).
    window = pad_prompt(prompt_tokens, max_len=context_size)

    # Only real tokens go into the output. Padding stays confined to `window`,
    # so nothing has to be stripped from the text afterwards and genuine
    # underscores in the prompt or the continuation are preserved.
    output_tokens = list(prompt_tokens[-context_size:])

    try:
        for _ in range(max_words):
            label, distribution, distance = classifier.classify(window)

            # Add the instance that was just classified to the instance base
            classifier.append(window, label)

            log(f"Prompt words: {window}", level=2)
            log(f"Classlabel: {label}", level=2)
            log(f"Distribution: {distribution}", level=3)
            log(f"Distance: {distance}", level=3)

            output_tokens.append(label)

            # Drop the oldest token and append the prediction
            window = window[1:] + [label]

            if stop_token is not None and label == stop_token:
                break

        generated_text = tokenizer.convert_tokens_to_string(output_tokens).strip()
        log(f"Generated text: {generated_text}", level=1)
        return generated_text

    except Exception as e:
        # Best effort for interactive use: report the failure and let the
        # caller continue with the next prompt.
        log(f"Error: {e}", level=1)
        return None


In [None]:
if __name__ == "__main__":

    # Arguments a command-line run would receive. They are handed to the
    # parser explicitly so the kernel's own sys.argv is left untouched.
    notebook_argv = [
        '--classifier', '/content/drive/MyDrive/dolly-15k-dutch-train_tok.l16r0',
        '--timbl_args', '-a4 +D',
        '--verbosity', '3',
    ]

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True, help="Path to the Timbl classifier file")
    parser.add_argument("--timbl_args", type=str, required=True, help="Timbl arguments as a single string (e.g., '-a4 +D')")
    parser.add_argument("--verbosity", type=int, default=0, help="Verbosity level (0: silent, 1: basic, 2: detailed, 3: debug)")
    args = parser.parse_args(notebook_argv)

    # Set global verbosity level (this cell runs at module scope, so the
    # assignment rebinds the global that log() reads)
    VERBOSITY = args.verbosity

    # Initialize the classifier
    classifier = timbl.TimblClassifier(args.classifier, args.timbl_args)
    classifier.load()

    # Keep asking for prompts until the user quits
    while True:
        try:
            user_input = input("Please enter prompt (or type 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Interrupting the cell or a closed stdin ends the session
            # cleanly instead of surfacing a traceback.
            log("Exiting.", level=1)
            break

        # Surrounding whitespace must not prevent the exit command from matching
        if user_input.strip().lower() == 'exit':
            log("Exiting.", level=1)
            break

        # Memorize the prompt and generate a continuation
        generate_text_from_api(classifier, user_input)
