# Simple textgen setup

This notebook generates text with a pretrained model and tokenizer through a simple widget interface.

## Setup cells
Run these cells only once.

In [1]:
from transformers import GPTNeoForCausalLM, AutoTokenizer
from IPython.display import display
import ipywidgets as widgets
import time
import torch


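Below you may set the model and tokenizer you want to use. You can find a list of available models [here](https://huggingface.co/models?pipeline_tag=text-generation). By default, the notebook uses the `EleutherAI/gpt-neo-2.7B` model with its tokenizer. Note that `GPTNeoForCausalLM` only loads GPT-Neo checkpoints, so a model with a different architecture needs a different model class (for example `AutoModelForCausalLM`).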

In [2]:
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")

In [3]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
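
If you want to try a checkpoint that is not a GPT-Neo model, one option is to load it through the `Auto*` classes from `transformers`. The helper below is only a sketch: the name `load_model_and_tokenizer` is hypothetical and not part of this notebook, and it assumes the checkpoint you pick is a causal language model.

```python
def load_model_and_tokenizer(model_name):
    """Load a causal language model checkpoint and its tokenizer by name.

    Hypothetical helper, not part of the original notebook.
    """
    # Imported inside the function so that defining the helper stays cheap
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Example (downloads the weights on first use):
# model, tokenizer = load_model_and_tokenizer("EleutherAI/gpt-neo-1.3B")
```

The rest of the notebook only relies on the global names `model` and `tokenizer`, so assigning the returned pair to those names should be enough to switch models.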

## Text generation setup

The following cells define the helper functions that will be used for text generation. You shouldn't need to change anything here.

In [4]:
# Helpful constants
TEMPERATURE = 0.9

In [5]:
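def genFromString(prompt, num_tokens, temperature=TEMPERATURE):
    # Samples up to num_tokens new tokens after the prompt.
    # Returns (full text including the prompt, number of tokens added).
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # max_length counts the prompt tokens too, so add them to the budget
    max_length = input_ids.shape[1] + num_tokens
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
    # GPT-Neo has no dedicated padding token, so the end-of-text token is used instead
    output = model.generate(input_ids, max_length=max_length, do_sample=True, temperature=temperature, attention_mask=attention_mask, pad_token_id=tokenizer.eos_token_id)
    return (tokenizer.decode(output[0], skip_special_tokens=True), len(output[0]) - len(input_ids[0]))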


In [6]:
#print(genFromString("Hey everyone! Welcome to the stream. Today, we'll be playing ", 75))

## Iterative generation setup

These cells set up the iterative generator, which produces text a fixed number of tokens at a time. This is useful for long texts: each chunk is appended to the text so far, and the result becomes the prompt for the next chunk.
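
The chunking idea can be sketched without a model. In the example below, `fake_generate` is a hypothetical stand-in for the real model call (it just appends random words), and `chunked_generate` is an illustrative name, not a function from this notebook. The sketch assumes the last chunk should be shortened so that the total token budget is never exceeded.

```python
import random

def fake_generate(prompt, num_tokens, rng):
    """Hypothetical stand-in for a model call: appends num_tokens random words."""
    vocab = ["alpha", "beta", "gamma", "delta"]
    new_words = [rng.choice(vocab) for _ in range(num_tokens)]
    return prompt + " " + " ".join(new_words), num_tokens

def chunked_generate(prompt, total_tokens, chunk_size, seed=0):
    """Yield (text_so_far, tokens_so_far) after each chunk."""
    rng = random.Random(seed)
    generated = 0
    while generated < total_tokens:
        # Clamp the last chunk so the total budget is never exceeded
        step = min(chunk_size, total_tokens - generated)
        # The text produced so far becomes the prompt for the next chunk
        prompt, produced = fake_generate(prompt, step, rng)
        generated += produced
        yield prompt, generated

progress = [count for _, count in chunked_generate("Once upon a time", 100, 30)]
print(progress)  # [30, 60, 90, 100]
```

Because the function is a generator, the caller can update a display after every chunk instead of waiting for the full text.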

In [7]:
# Helper for rendering a duration in a human-readable way
def formatTime(seconds):
    # Returns one of:
    #   XX.XXms    (under one second)
    #   XX.XXs     (under one minute)
    #   XmXX.XXs   (one minute or more)
    if seconds < 1:
        return '{:.2f}ms'.format(seconds * 1000)
    elif seconds < 60:
        return '{:.2f}s'.format(seconds)
    else:
        # Show whole minutes as an integer instead of a float such as "2.00m"
        minutes, secs = divmod(seconds, 60)
        return '{:d}m{:.2f}s'.format(int(minutes), secs)

## Generator

The following cells define the functions that build the generator. You can think of them as the "frontend" of the notebook: they create the widgets and connect them to the text generation helpers above.

In [8]:
def iterativeGen(initial_prompt, total_tokens=100, chunk_size=50, temperature=TEMPERATURE):
    prompt = initial_prompt
    tokens_generated = 0
    while tokens_generated < total_tokens:
        # Never request more than the remaining budget, so the total is not overshot
        remaining = total_tokens - tokens_generated

        # Generate the next tokens
        next_text, num_tokens = genFromString(prompt, min(chunk_size, remaining), temperature=temperature)

        # The text generated so far is the prompt for the next iteration
        prompt = next_text

        # Update the number of tokens generated
        tokens_generated += num_tokens
        yield (prompt, tokens_generated)

In [9]:
async def generator():
    # Create a widget that will display the generated text
    text = widgets.Textarea()
    # Set the text area to be wider than the default
    text.layout.width = '75%'
    
    # Create a button to generate "more text"
    moreButton = widgets.Button(description='More')
    progressBar = widgets.IntProgress(
        min = 0,
        max = 100,
        bar_style='success',
        value = 0,
    )
    
    # Create a slider to control the number of tokens per chunk
    chunkSlider = widgets.IntSlider(
        value = 50,
        min = 1,
        max = 100,
        description = 'Tokens/chk:',
        orientation = 'horizontal',
    )
    
    # Create a slider to control the temperature
    tempSlider = widgets.FloatSlider(
        value = TEMPERATURE,
        min = 0.1,
        max = 1.0,
        step = 0.1,
        description = 'Temp:',
        orientation = 'horizontal',
    )
    
    def on_button_clicked(b):
        print('Generating more text...')
        start = time.time()
        # Lock the controls while generation is running
        text.disabled = True
        moreButton.disabled = True
        chunkSlider.disabled = True
        tempSlider.disabled = True
        progressBar.bar_style = 'info'
        progressBar.value = 0
        times = []
        tokens_generated = 0
        try:
            for prompt, tokens_generated in iterativeGen(text.value, progressBar.max, chunkSlider.value, temperature=tempSlider.value):
                text.value = prompt
                progressBar.value = tokens_generated
                # Record how long this chunk took, then restart the timer
                times.append(time.time() - start)
                start = time.time()
        finally:
            # Re-enable the controls even if generation raised an error
            text.disabled = False
            moreButton.disabled = False
            chunkSlider.disabled = False
            tempSlider.disabled = False
        progressBar.bar_style = 'success'

        if times:
            print('Done in {:.2f} seconds (avg {:.2f}s/chunk, {:.2f}s/token)! Generated {} total tokens over {} chunks.'.format(sum(times), sum(times) / len(times), sum(times) / tokens_generated, tokens_generated, len(times)))
    
    moreButton.on_click(on_button_clicked)
    # Create a vertical box to hold the buttons
    vbox = widgets.VBox([moreButton, chunkSlider, tempSlider, progressBar])
    
    # Create a horizontal box to hold the text and buttons
    hbox = widgets.HBox([text, vbox])
    display(hbox)
    
    

## Running the generator

This final cell creates the generator interface, which allows you to generate text using the model and tokenizer you selected above.

To use it, begin by writing some text in the textbox below. Then, click on the "More" button to launch the generation.

You may also use the two sliders to control the number of tokens generated per chunk and the sampling temperature. Fewer tokens per chunk makes new text appear sooner, but the whole text is processed again for every chunk, so the full run takes longer, and it might decrease the quality of the text. Higher temperatures increase the diversity of the text, but might also decrease its quality.
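
To see why temperature affects diversity, here is a minimal sketch of temperature scaling in plain Python. It assumes the usual formulation, in which the model's raw scores (logits) are divided by the temperature before the softmax. The function name and the example scores are made up for illustration.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities, dividing by the temperature first."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical scores for three candidate tokens
for t in (0.1, 0.5, 1.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At low temperature almost all of the probability lands on the highest-scoring token, so sampling becomes nearly deterministic. As the temperature approaches 1.0, the other candidates keep a larger share and the output varies more.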

In [10]:
await generator()

HBox(children=(Textarea(value='', layout=Layout(width='75%')), VBox(children=(Button(description='More', style…

Generating more text...
Done in 42.33 seconds (avg 21.16s/chunk, 0.42s/token)! Generated 100 total tokens over 2 chunks.
