<a href="https://colab.research.google.com/github/antalvdb/mblm/blob/main/timbl_llm_server.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython.display import HTML, display

def set_css(info=None):
    """Inject CSS so that ``<pre>`` output blocks wrap long lines.

    Runs before every cell through the IPython ``pre_run_cell`` hook
    registered below. IPython passes an execution-info object to
    ``pre_run_cell`` callbacks; ``info`` accepts and ignores it so the hook
    does not rely on IPython adapting zero-argument callbacks.
    """
    display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))


# Re-running this cell defines a fresh `set_css` object each time; remove any
# previously registered copy so the <style> tag is not injected repeatedly.
_events = get_ipython().events
for _callback in list(_events.callbacks['pre_run_cell']):
    if getattr(_callback, '__name__', None) == 'set_css':
        _events.unregister('pre_run_cell', _callback)
_events.register('pre_run_cell', set_css)

In [None]:
from transformers import AutoTokenizer
import re
import time
import argparse
import sys

In [None]:
!pip install python3-timbl

import timbl

In [None]:
# Global verbosity level read by `log`: a message is printed only when
# VERBOSITY >= the message's level. Errors are logged at level 1; set this
# to 3 to see per-step debug traces of the server exchange.
VERBOSITY = 1

In [None]:
def log(message, level=1):
    """Print ``message`` when the global ``VERBOSITY`` reaches ``level``.

    Args:
        message: Object to print.
        level: Minimum verbosity needed for the message to be shown
            (defaults to 1).
    """
    if level > VERBOSITY:
        return  # too detailed for the current verbosity setting
    print(message)

In [None]:
def pad_prompt(words, max_len=16, pad_token='_'):
    """Return ``words`` left-padded or left-trimmed to exactly ``max_len`` items.

    Padding is prepended and trimming drops the oldest (leftmost) tokens, so
    the most recent tokens always sit at the end of the returned window.

    Args:
        words: Sequence of tokens, or None (treated as empty). Not mutated.
        max_len: Desired window length; must be >= 0.
        pad_token: Filler token used for padding (default ``'_'``).

    Returns:
        A new list of exactly ``max_len`` tokens.

    Raises:
        ValueError: If ``max_len`` is negative.
    """
    if max_len < 0:
        raise ValueError(f"max_len must be >= 0, got {max_len}")
    # Copy into a list so tuples/other sequences work and the caller's
    # object is never aliased or mutated.
    words = [] if words is None else list(words)
    if len(words) >= max_len:
        # Slice from an explicit start index: words[-0:] would return the
        # whole list instead of an empty one when max_len == 0.
        return words[len(words) - max_len:]
    return [pad_token] * (max_len - len(words)) + words

In [None]:
import socket
import time
import re

# Captures the text inside the first {...} group of a CATEGORY reply line.
_CATEGORY_RE = re.compile(r"\{(.*?)\}")


def _parse_category(line):
    """Return the text inside the first ``{...}`` of a ``CATEGORY`` line, else None."""
    if not line.startswith("CATEGORY "):
        return None
    match = _CATEGORY_RE.search(line)
    return match.group(1) if match else None


def _receive_category(sock, buffer):
    """Read from ``sock`` until a parsable ``CATEGORY`` line arrives.

    Data is buffered as bytes and split on ``\\n`` *before* decoding, so a
    multi-byte UTF-8 character split across two ``recv`` calls is decoded
    correctly. Lines that are not CATEGORY lines are logged and skipped.

    Args:
        sock: Connected socket to read from.
        buffer: Bytes already received but not yet consumed.

    Returns:
        ``(word, leftover)``: the predicted token and any bytes received
        after the CATEGORY line (to pass into the next call).

    Raises:
        ConnectionError: If the server closes the connection first.
    """
    while True:
        while b"\n" in buffer:
            raw_line, buffer = buffer.split(b"\n", 1)
            line = raw_line.decode("utf-8").rstrip("\r")
            log(f"Processing line: {line}", level=3)
            word = _parse_category(line)
            if word is not None:
                log(f"Extracted next word: {word}", level=3)
                return word, buffer
        chunk = sock.recv(4096)
        if not chunk:
            # recv() returns b"" once the peer has closed the connection;
            # without this check the loop would spin forever.
            raise ConnectionError(
                "server closed the connection before sending a CATEGORY line")
        buffer += chunk
        log(f"Received data (accumulated): {buffer.decode('utf-8', errors='replace')}", level=3)


def generate_text_from_server(host, port, initial_prompt, max_words=200,
                              context_size=16, timeout=None):
    """Stream up to ``max_words`` predicted tokens from a classification server.

    The prompt is tokenized with the module-level ``tokenizer`` (loaded in a
    later cell) and padded/trimmed to ``context_size`` tokens with
    ``pad_prompt``. Each step sends ``classify <tokens> ??`` and reads the
    reply's ``CATEGORY {token}`` line. The predicted token is printed
    (detokenized) immediately and shifted into the sliding context window for
    the next request.

    Args:
        host: Server hostname or IP address.
        port: Server TCP port.
        initial_prompt: Text to continue.
        max_words: Maximum number of tokens to request.
        context_size: Number of tokens sent per request. NOTE(review):
            presumably must match the feature count of the server's instance
            base; confirm against the server configuration.
        timeout: Socket timeout in seconds; None blocks indefinitely.

    Network and UTF-8 decoding errors are logged at level 1 rather than
    raised. A trailing newline is always printed.
    """
    initial_tokens = tokenizer.tokenize(initial_prompt)
    if initial_tokens is None:
        print("Tokenization failed; 'initial_tokens' is None.")
        initial_tokens = []

    prompt_words = pad_prompt(initial_tokens, max_len=context_size)
    log(f"Initial prompt tokens: {' '.join(initial_tokens)}", level=3)

    try:
        with socket.create_connection((host, port), timeout=timeout) as client_socket:
            welcome = client_socket.recv(1024)
            if not welcome:
                raise ConnectionError(
                    "server closed the connection before sending a welcome message")
            # Log-only, so tolerate a multi-byte character cut at the chunk edge.
            log(f"Received welcome message: {welcome.decode('utf-8', errors='replace')}", level=3)

            buffer = b""  # bytes received but not yet consumed as lines
            for _ in range(max_words):
                log(f"Current prompt: {' '.join(prompt_words)}", level=3)

                # "??" marks the unknown class the server should predict.
                message = "classify " + " ".join(prompt_words) + " ??\n"
                client_socket.sendall(message.encode('utf-8'))

                next_word, buffer = _receive_category(client_socket, buffer)

                # Flush so each token appears as soon as it is predicted.
                print(tokenizer.convert_tokens_to_string([next_word]), end="", flush=True)
                log(f"Predicted word: {next_word}", level=3)

                # Slide the window: drop the oldest token, append the prediction.
                prompt_words = prompt_words[1:] + [next_word]
    except (OSError, UnicodeDecodeError) as e:
        # OSError covers refused/closed connections, timeouts and DNS errors.
        log(f"Error: {e}", level=1)

    # Finish the streamed line for cleaner output.
    print()

In [None]:
# GPT-2 tokenizer; `generate_text_from_server` reads this module-level name
# to split the prompt and to detokenize each predicted token.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Example usage: address of the remote server (IP or hostname) and its port.
host, port = '85.215.105.128', 8001

# Ask for a prompt, then stream the server's continuation to the output.
initial_prompt = input("Enter your initial prompt: ")
generate_text_from_server(host=host, port=port, initial_prompt=initial_prompt)