<a href="https://colab.research.google.com/github/antalvdb/mblm/blob/main/timbl_llm_hf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple integration of MBLM in Hugging Face

This notebook introduces the `TimblHuggingFaceModel` class and shows how a memory-based language model (MBLM), built on TiMBL, can be called in the style of a Hugging Face model.

In [1]:
import transformers
import re
import time
import argparse
import sys
import ast

Install the `python3-timbl` package, which provides the `timbl` Python module.

In [2]:
!pip install python3-timbl

import timbl

Collecting python3-timbl
  Downloading python3_timbl-2025.1.22-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (824 bytes)
Downloading python3_timbl-2025.1.22-cp311-cp311-manylinux_2_28_x86_64.whl (21.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.7/21.7 MB 22.8 MB/s eta 0:00:00
Installing collected packages: python3-timbl
Successfully installed python3-timbl-2025.1.22


For now, we load a TiMBL model from GitHub; downloading from Hugging Face is not yet implemented. The model file is about 0.5 GB and is stored with Git LFS, so an explicit `git lfs pull` is needed after cloning.

In [3]:
!git clone https://github.com/antalvdb/mblm
%cd mblm
!git lfs pull -I chatbot-instruction-prompts_tok.l16r0.igtree.ibase
%cd

Cloning into 'mblm'...
remote: Enumerating objects: 111, done.
remote: Counting objects: 100% (111/111), done.
remote: Compressing objects: 100% (101/101), done.
remote: Total 111 (delta 59), reused 26 (delta 10), pack-reused 0 (from 0)
Receiving objects: 100% (111/111), 105.28 KiB | 7.52 MiB/s, done.
Resolving deltas: 100% (59/59), done.
/content/mblm
/root


The `TimblHuggingFaceModel` class wraps a TiMBL classifier in a Hugging Face `PreTrainedModel`. Its main method is `forward`, which queries the TiMBL classifier and produces a distribution over next tokens. The methods `convert_to_timbl_input` and `convert_to_huggingface_logits` bridge the two representation systems.

In [4]:
import torch
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel

class TimblHuggingFaceModel(PreTrainedModel):

    # Replace a matched value with an actual float
    # (static: it takes a regex match object, not self)
    @staticmethod
    def float_converter(match):
        return f"{match.group(1)}: {float(match.group(2))}"

    def __init__(self, config, timbl_classifier, tokenizer):
        super().__init__(config)
        self.timbl_classifier = timbl_classifier
        self.tokenizer = tokenizer  # Store tokenizer

    def forward(self, input_ids, **kwargs):

        #print("inside forward")

        # Convert input_ids to Timbl format
        timbl_input = self.convert_to_timbl_input(input_ids)
        log(f"Timbl input: {timbl_input}",level=3)

        # Get Timbl predictions
        classlabel, distribution, distance = self.timbl_classifier.classify(timbl_input)
        log(f"Classlabel: {classlabel}", level=3)
        log(f"Distribution: {distribution}", level=3)
        log(f"Distance: {distance}", level=3)

        # Convert Timbl output to Hugging Face format
        logits = self.convert_to_huggingface_logits(distribution)
        log(f"Logits: {logits}", level = 3)

        # Return logits and other relevant outputs
        return transformers.modeling_outputs.CausalLMOutputWithCrossAttentions(logits=logits)

    def convert_to_timbl_input(self, input_ids):
        """Converts Hugging Face input_ids to Timbl input format."""

        #print("inside convert_to_timbl_input")

        # Convert input_ids to a list of token strings
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))
        log(f"Tokens: {tokens}", level = 3)

        # Return the array of tokens directly
        return tokens

    def convert_to_huggingface_logits(self, distribution):

        #print("inside convert_to_huggingface_logits")

        # Bypass the usual Hugging Face device handling and keep everything on the CPU
        device = "cpu"

        # Get vocabulary size from the tokenizer
        vocab_size = self.tokenizer.vocab_size

        # Initialize logits with a default value (e.g., -inf)
        logits = torch.full((1, vocab_size), float('-inf'), device=device)
        log(f"Logits: {logits}",level=3)

        # Fill logits with the values from the Timbl distribution, stored as-is
        # (not log-transformed); the ranking of tokens is unchanged
        for word, probability in distribution.items():
            hf_token_id = self.tokenizer.convert_tokens_to_ids(word)
            log(f"hf_token_id: {hf_token_id}", level = 4)

            # Check if hf_token_id is a list and take the first element if it is
            # Handling nested lists as well
            while isinstance(hf_token_id, list) and len(hf_token_id) > 0:
                hf_token_id = hf_token_id[0]
                log(f"hf_token_id: {hf_token_id}", level = 4)

            if isinstance(hf_token_id, int):  # Ensure it's now an integer
                try:
                    logits[0, hf_token_id] = torch.tensor(probability, device=device)
                    log(f"logits[0, hf_token_id]: {logits[0, hf_token_id]}", level=4)
                    log(f"Logits shape: {logits.shape}", level = 4)
                except IndexError:
                    # Handle the case where hf_token_id is out of bounds
                    log(f"Warning: Token ID {hf_token_id} is out of bounds for logits shape {logits.shape}", level=1)
            else:
                log(f"Warning: Skipping word '{word}' due to unexpected token ID format: {hf_token_id}", level=1)

        return logits
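
As a rough illustration of what `convert_to_huggingface_logits` does, the sketch below performs the same mapping in plain Python, without `torch` or a real tokenizer. The five-entry `toy_vocab`, the example distribution, and the helper name `distribution_to_scores` are made up for illustration. Unlike this sketch, which skips unknown tokens, a real tokenizer may map them to an unknown-token ID.

```python
# Toy vocabulary standing in for tokenizer.convert_tokens_to_ids (hypothetical)
toy_vocab = {"[PAD]": 0, "the": 1, "cat": 2, "sat": 3, "mat": 4}

def distribution_to_scores(distribution, vocab):
    """Map a sparse {token: weight} distribution onto a dense score list."""
    # Every token starts at -inf, so tokens absent from the
    # distribution can never win an argmax
    scores = [float("-inf")] * len(vocab)
    for token, weight in distribution.items():
        token_id = vocab.get(token)
        if token_id is None:
            # Simplification: skip tokens the toy vocabulary does not know
            continue
        scores[token_id] = float(weight)
    return scores

# Made-up TiMBL-style distribution over next tokens
example_distribution = {"cat": 0.7, "mat": 0.2, "dog": 0.1}
scores = distribution_to_scores(example_distribution, toy_vocab)

# Index of the highest score, i.e. a plain-Python argmax
best_id = max(range(len(scores)), key=scores.__getitem__)
print(scores)   # [-inf, -inf, 0.7, -inf, 0.2]
print(best_id)  # 2
```

Only the tokens returned by the classifier receive a finite score, which is why the dense vector can be handed to an argmax without further normalization.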

Set a global verbosity level. The logging function `log()` prints a message only if the global `VERBOSITY` is at least the level given for that message, so higher settings produce more output.

In [5]:
# Global verbosity level
VERBOSITY = 3

In [6]:
def log(message, level=1):
    """Logs a message if the verbosity level is sufficient."""
    if VERBOSITY >= level:
        print(message)
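
To make the gating concrete, here is a small self-contained sketch of the same pattern. It collects messages in a list instead of printing them, so the effect of each verbosity setting is easy to inspect. The `emitted` list and the `demo` helper are illustrative additions, not part of the notebook.

```python
VERBOSITY = 0
emitted = []

def log(message, level=1):
    """Record a message if the global verbosity is at least `level`."""
    if VERBOSITY >= level:
        emitted.append(message)

def demo(verbosity):
    """Return the messages that pass the gate at a given verbosity."""
    global VERBOSITY
    VERBOSITY = verbosity
    emitted.clear()
    log("basic", level=1)
    log("detailed", level=2)
    log("debug", level=3)
    return list(emitted)

print(demo(0))  # []
print(demo(2))  # ['basic', 'detailed']
```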

The function `pad_prompt` initializes the input window for MBLM: it left-pads the token list with `_`, or trims it to its last `max_len` tokens (16 by default).

In [7]:
def pad_prompt(words, max_len=16):
    """Pad or trim the list of words to make it exactly `max_len` words."""
    if words is None:
        words = []  # Ensure words is a list
    if len(words) < max_len:
        words = ['_'] * (max_len - len(words)) + words
    else:
        words = words[-max_len:]
    return words
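
A quick usage sketch of `pad_prompt` follows. The function is repeated so the snippet runs on its own, and a window of 5 is used instead of 16 only to keep the output short.

```python
def pad_prompt(words, max_len=16):
    """Pad or trim the list of words to make it exactly `max_len` words."""
    if words is None:
        words = []
    if len(words) < max_len:
        words = ['_'] * (max_len - len(words)) + words
    else:
        words = words[-max_len:]
    return words

# Short input: padding is added on the left
short = pad_prompt(["Hello", "world"], max_len=5)
print(short)  # ['_', '_', '_', 'Hello', 'world']

# Long input: only the last max_len items are kept
trimmed = pad_prompt(["a", "b", "c", "d", "e", "f", "g"], max_len=5)
print(trimmed)  # ['c', 'd', 'e', 'f', 'g']
```

Padding on the left keeps the most recent tokens at the right edge of the window, next to the position being predicted.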

The function `generate_text_from_api` wraps around MBLM to generate a full text based on a prompt.

In [8]:
def generate_text_from_api(model, tokenizer, initial_prompt, max_words=200):
    # Tokenize the initial prompt into a list of tokens
    initial_tokens = tokenizer.tokenize(initial_prompt)

    if initial_tokens is None:
        log("Tokenization failed; 'initial_tokens' is None.", level=1)
        initial_tokens = []

    # Collect (context window, next token) instances from the prompt
    padded_instances = []

    # Generate padded instances for next-token predictions
    for i in range(len(initial_tokens)):
        # Take the tokens up to the current position and pad them
        instance = pad_prompt(initial_tokens[:i], max_len=16)
        padded_instances.append((instance, initial_tokens[i]))

    # Add instances to memory using the TimblHuggingFaceModel (if applicable)
    # This part might need to be adapted depending on how memory is handled in TimblHuggingFaceModel
    for input_instance, next_token in padded_instances:
        log(f"memorized: {input_instance} {next_token}", level=2)
        # Adapt this line to call the appropriate method of TimblHuggingFaceModel for adding to memory
        # model.append(input_instance, next_token)

    # Use the final part of the prompt for further generation
    prompt_words = pad_prompt(initial_tokens)

    generated_tokens = prompt_words[:]  # Store the full generated text

    try:
        # Generate up to max_words tokens (no early stopping is implemented)
        for _ in range(max_words):
            next_word = None

            # Get prediction from TimblHuggingFaceModel
            encoded_input = tokenizer.encode(" ".join(prompt_words), return_tensors="pt", padding="max_length", truncation=True, max_length=16, add_special_tokens=False)
            log(f"encoded_input: {encoded_input}", level = 3)
            input_ids = encoded_input[0]  # Access the first sequence
            log(f"input_ids: {input_ids}", level = 3)
            outputs = model(input_ids) # get model output
            log(f"outputs: {outputs}", level = 3)
            logits = outputs.logits # extract logits
            log(f"logits: {logits}", level = 3)

            # Get the predicted token ID, skipping index 0 ([PAD] in bert-base-cased)
            predicted_token_id = torch.argmax(logits[:, 1:], dim=-1).item() + 1  # Add 1 to shift back to original index
            #predicted_token_id = torch.argmax(logits, dim=-1).item()

            # Decode the token ID to a word
            classlabel = tokenizer.decode(predicted_token_id)

            # Add instance to instance base (if applicable)
            # Adapt this line to call the appropriate method of TimblHuggingFaceModel for adding to memory
            # model.append(prompt_words, classlabel)

            log(f"Prompt words: {prompt_words}", level=2)
            log(f"Classlabel: {classlabel}", level=2)

            generated_tokens.append(classlabel)

            # Shift prompt words and add the new word
            prompt_words = prompt_words[1:] + [classlabel]

        # Detokenize the generated tokens
        generated_text = tokenizer.convert_tokens_to_string(generated_tokens)

        # Strip off original padding characters
        generated_text = generated_text.replace("_", "").strip()

        # Print the final generated text
        log(f"Generated text: {generated_text}", level=1)

    except Exception as e:
        log(f"Error: {e}", level=1)
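
The core of the loop above is a fixed-width window that slides one token to the right after every prediction. The sketch below isolates that mechanism in plain Python. `toy_predict` and the `NEXT` table are hypothetical stand-ins for the model call, and the window width of 4 is chosen only for readability. Unlike the notebook function, the sketch collects the unpadded prompt tokens, so no `_` stripping is needed at the end.

```python
def pad_prompt(words, max_len=16):
    """Left-pad with '_' or keep only the last `max_len` items."""
    words = list(words or [])
    if len(words) < max_len:
        return ["_"] * (max_len - len(words)) + words
    return words[-max_len:]

# Hypothetical next-token table standing in for the TiMBL classifier
NEXT = {"the": "cat", "cat": "sat", "sat": "down"}

def toy_predict(window):
    """Predict the next token from the last token in the window."""
    return NEXT.get(window[-1], ".")

def generate(prompt_tokens, max_words=3, max_len=4):
    window = pad_prompt(prompt_tokens, max_len)
    generated = list(prompt_tokens)
    for _ in range(max_words):
        next_token = toy_predict(window)
        generated.append(next_token)
        # Drop the oldest token and append the prediction,
        # so the window width stays fixed
        window = window[1:] + [next_token]
    return generated, window

tokens, final_window = generate(["the"])
print(tokens)        # ['the', 'cat', 'sat', 'down']
print(final_window)  # ['the', 'cat', 'sat', 'down']
```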

The main code block initializes the MBLM model and the accompanying tokenizer, and runs a loop in which the user is invited to give prompts, triggering the generation of text.

In [None]:
if __name__ == "__main__":

    # Simulate command-line arguments for notebook environment
    # sys.argv = ['script_name', '--classifier', '/content/drive/MyDrive/dolly-15k-dutch-train_tok.l16r0', '--timbl_args', '-a4 +D', '--verbosity', '3']
    sys.argv = ['script_name', '--classifier', '/content/mblm/chatbot-instruction-prompts_tok.l16r0.igtree', '--tokenizer', 'bert-base-cased', '--timbl_args', '-a1 +D', '--verbosity', '1']

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True, help="Path to the Timbl classifier file")
    parser.add_argument("--tokenizer", type=str, required=True, help="Name of the Hugging Face tokenizer")
    parser.add_argument("--timbl_args", type=str, required=True, help="Timbl arguments as a single string (e.g., '-a4 +D')")
    parser.add_argument("--verbosity", type=int, default=0, help="Verbosity level (0: silent, 1: basic, 2: detailed, 3: debug, 4: per-token trace)")
    args = parser.parse_args()

    # Set global verbosity level
    VERBOSITY = args.verbosity

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    log(f"Tokenizer {tokenizer} loaded", level=1)

    # Initialize the Timbl classifier
    classifier = timbl.TimblClassifier(args.classifier, args.timbl_args)
    classifier.load()
    log("TiMBL classifier loaded", level=1)

    config = AutoConfig.from_pretrained("antalvdb/mblm-chatbot-instruction-prompts-igtree")
    tokenizer.add_special_tokens({'pad_token': '_'})
    tokenizer.pad_token = "_"
    #print(config)

    # Initialize the TimblHuggingFaceModel
    model = TimblHuggingFaceModel(config, classifier, tokenizer)  # Pass tokenizer

    # Single prompt test:
    # user_input = input("Please enter prompt: ")
    # generate_text_from_api(model, tokenizer, user_input)

    # Loop to continuously ask for input and classify
    while True:
        # Take input from the user
        user_input = input("Please enter prompt (or type 'exit' to quit): ")

        # Check if the user wants to exit
        if user_input.lower() == 'exit':
            log("Exiting.", level=1)
            break

        # Generate text from the prompt
        generate_text_from_api(model, tokenizer, user_input)



args.verbosity: 1
VERBOSITY: 1


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Tokenizer BertTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
) loaded


Calling Timbl API : -F Tabbed -a1 +D


TiMBL classifier loaded


config.json:   0%|          | 0.00/262 [00:00<?, ?B/s]
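
The main block above simulates command-line arguments by overwriting `sys.argv`. A possible alternative in a notebook, sketched here, is to pass the argument list directly to `parse_args`, which leaves `sys.argv` untouched. The argument values are copied from the cell above; the `build_parser` helper is an illustrative name, and the help strings are omitted for brevity.

```python
import argparse

def build_parser():
    """Build the same four-option parser as the main block (help text omitted)."""
    parser = argparse.ArgumentParser(description="Memory-based text generator")
    parser.add_argument("--classifier", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--timbl_args", type=str, required=True)
    parser.add_argument("--verbosity", type=int, default=0)
    return parser

# Passing a list to parse_args() avoids touching sys.argv.
# The '--timbl_args=...' form keeps the value, which itself starts
# with '-', from being mistaken for another option.
args = build_parser().parse_args([
    "--classifier", "/content/mblm/chatbot-instruction-prompts_tok.l16r0.igtree",
    "--tokenizer", "bert-base-cased",
    "--timbl_args=-a1 +D",
    "--verbosity", "1",
])

print(args.timbl_args)  # -a1 +D
print(args.verbosity)   # 1
```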