# Scene Generation Prototype

Check whether CUDA is available on your machine.

In [15]:
import torch
# Report whether PyTorch can see a CUDA device (True) or not (False).
print(torch.cuda.is_available())

True


If this prints `False`, remove `, device=0` from the `pipeline(...)` call in the next code cell.

Also make sure to provide the correct path to the model you are using.

In [32]:
from transformers import pipeline

# Text-generation pipeline shared by every later cell.
# - model: replace the placeholder with the path to your model.
# - tokenizer: the German GPT-2 tokenizer.
# - device=0: first device index.
pipe = pipeline(
    'text-generation',
    model="<ENTER MODEL PATH HERE>",
    tokenizer="dbmdz/german-gpt2",
    device=0,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Check whether you can prompt the model.

In [33]:
# Smoke test: prompt the model with a bare "<s>" token and print only the
# newly generated continuation (at most 50 new tokens).
output = pipe(
    "<s>",
    return_full_text=False,
    max_new_tokens=50,
)[0]["generated_text"]
print(output)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


UNKNOWN: Was sind nun wieder?MALE: Herr Baron!MALE: Du lest mir das Blatt, mein Junge! Zu dir, zu wem?MALE: O, fürwahr, dein Vater wollte


---------------
Run both code cells below, then open the local URL that is printed after the last cell.

In [34]:
# Dramatic Networks
import numpy as np

def generate_choice(choices):
  """Draw one key of ``choices`` with probability proportional to its weight.

  The weights do not have to sum to 1: the uniform draw is scaled by the
  total weight, so e.g. ``{"A": 0.8, "B": 0.2, "C": 0.2}`` still gives every
  key a chance and a total below 1 can no longer fall through and return
  ``None``.

  Args:
    choices: mapping from key to a non-negative numeric weight.

  Returns:
    One key of ``choices`` whose weight is strictly positive.

  Raises:
    ValueError: if a weight is negative, or if no weight is positive
      (this includes an empty mapping).
  """
  total = 0.0
  for weight in choices.values():
    if weight < 0:
      raise ValueError("choice weights must be non-negative")
    total += weight
  if total <= 0:
    raise ValueError("choices must contain at least one positive weight")

  # Scale the [0, 1) draw onto [0, total) so un-normalised weights work.
  threshold = np.random.uniform() * total

  cumulative = 0.0
  chosen = None
  for key, weight in choices.items():
    if weight <= 0:
      continue  # zero-weight keys can never be selected
    cumulative += weight
    chosen = key
    if threshold < cumulative:
      break
  # If floating-point rounding kept the loop from breaking, ``chosen`` is the
  # last key with a positive weight, which is the correct bucket in that case.
  return chosen

def end_scene(end_prob):
  """Randomly decide whether the scene ends now.

  One uniform sample is drawn from [0, 1); the result is True exactly when
  that sample is strictly smaller than ``end_prob``.
  """
  draw = np.random.uniform()
  return bool(draw < end_prob)

def simulate_sequence(centrality, loyalty, scene_length_parameter, max_turns=20, initial_end_prob=0.01):
  """Simulate the order in which characters speak during one scene.

  The opening speaker is drawn from ``centrality``. Each following speaker is
  drawn from the current speaker's row in ``loyalty``; drawing the special key
  ``"X"`` means the next speaker is drawn from ``centrality`` instead. After
  every turn the end probability is multiplied by ``scene_length_parameter``
  and the scene may end at random.

  Args:
    centrality: mapping character -> weight, passed to ``generate_choice``.
    loyalty: mapping character -> mapping of (other character or ``"X"``)
      -> weight.
    scene_length_parameter: factor applied to the end probability after each
      turn.
    max_turns: upper bound on the number of turns after the opening speaker
      (previously hard-coded to 20).
    initial_end_prob: end probability before the first multiplication
      (previously hard-coded to 0.01).

  Returns:
    List of character keys in speaking order.
  """
  current_char = generate_choice(centrality)
  sequence = [current_char]
  end_prob = initial_end_prob

  for _ in range(max_turns):
    next_char = generate_choice(loyalty[current_char])
    if next_char == "X":
      # "X" = the current speaker addresses nobody in particular, so the
      # next speaker is picked by centrality alone.
      next_char = generate_choice(centrality)
    current_char = next_char
    sequence.append(current_char)

    end_prob *= scene_length_parameter
    if end_scene(end_prob):
      break
  return sequence


In [35]:
import gradio as gr
import re


### Methods and variables
# Weight of each character when a speaker is drawn without regard to the
# previous speaker. The defaults now add up to 1 (they used to add up to 1.2).
centrality = {
    "A": 0.6,
    "B": 0.2,
    "C": 0.2
}

# For every character: weight of each possible next speaker. The key "X"
# means "draw the next speaker from `centrality` instead".
loyalty = {
    "A": {
        "B": 0.1,
        "C": 0.2,
        "X": 0.7
    },
    "B": {
        "A": 0.8,
        "C": 0.1,
        "X": 0.1
    },
    "C": {
        "A": 0.8,
        "B": 0.1,
        "X": 0.1
    }
}

# Scene lines as shown to the user ("<name>: <text>").
display_history = []
# Scene lines as fed to the model ("<s><GENDER>: <text>").
internal_history = []

# Character name -> gender tag.
gender_dict = dict()
# Slot letter ("A"/"B"/"C") -> character name.
character_dict = dict()

# Remaining speakers (slot letters) of the current scene, in order.
character_sequence = []

# Text-generation settings; overwritten by `dialogue_init`.
parameters = {"top_p": 1,
              "top_k": 50,
              "temperature": 1,
              "repetition_penalty": 1,
              "context_length": 3
              }

def _normalize_weights(weights):
    """Return a copy of ``weights`` rescaled so that its values sum to 1.

    Raises:
        ValueError: if the values do not have a positive sum.
    """
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return {key: value / total for key, value in weights.items()}


def dialogue_init(top_p, top_k, temperature, repetition_penalty, scene_length, context_length, nameA, nameB, nameC, genderA, genderB, genderC, textA, textB, textC, centr_A, centr_B, centr_C, loyalty_A_B, loyalty_A_C, loyalty_A_X, loyalty_B_A, loyalty_B_C, loyalty_B_X, loyalty_C_A, loyalty_C_B, loyalty_C_X):
    """Reset all scene state from the UI inputs and plan the speaker order.

    The parameters arrive positionally from the Gradio click handler:
    generation settings, scene/context length, then name, gender and optional
    first line of characters A-C, then their centrality and loyalty weights.

    Behaviour:
    - Blank character names fall back to the slot letter ("A", "B", "C").
      Genders are looked up by name later on, so names must be distinct;
      otherwise an error message is returned and no state is changed.
    - Centrality and each loyalty row are rescaled to sum to 1. If one of
      them is all zeros an error message is returned and no state is changed.
    - ``top_k`` and ``context_length`` are stored as ints because they are
      used as a token count and a slice bound respectively.

    Returns:
        A status message followed by the planned speaker sequence and the
        initial dialogue, or an error message if the inputs are unusable.
    """
    slots = ("A", "B", "C")

    # Resolve and validate names before touching any shared state.
    names = {}
    for slot, raw_name in zip(slots, (nameA, nameB, nameC)):
        names[slot] = (raw_name or "").strip() or slot
    if len(set(names.values())) != len(slots):
        return "Initialization failed: character names must be distinct."

    try:
        new_centrality = _normalize_weights({"A": centr_A, "B": centr_B, "C": centr_C})
        new_loyalty = {
            "A": _normalize_weights({"B": loyalty_A_B, "C": loyalty_A_C, "X": loyalty_A_X}),
            "B": _normalize_weights({"A": loyalty_B_A, "C": loyalty_B_C, "X": loyalty_B_X}),
            "C": _normalize_weights({"A": loyalty_C_A, "B": loyalty_C_B, "X": loyalty_C_X}),
        }
    except ValueError:
        return "Initialization failed: the centrality values and each character's loyalty values must not all be zero."

    genders = dict(zip(slots, (genderA, genderB, genderC)))
    first_lines = dict(zip(slots, (textA, textB, textC)))

    # Start from a clean slate so nothing from a previous scene leaks through.
    internal_history.clear()
    display_history.clear()
    gender_dict.clear()
    character_dict.clear()

    centrality.update(new_centrality)
    for slot, row in new_loyalty.items():
        loyalty[slot].update(row)

    for slot in slots:
        name = names[slot]
        gender = genders[slot]
        character_dict[slot] = name
        gender_dict[name] = gender

        text = first_lines[slot]
        if text:
            internal_history.append(f"<s>{gender}: {text}")
            display_history.append(f"{name}: {text}")

    character_sequence.clear()
    character_sequence.extend(simulate_sequence(centrality, loyalty, scene_length))

    parameters["context_length"] = int(context_length)
    parameters["top_p"] = top_p
    parameters["top_k"] = int(top_k)
    parameters["temperature"] = temperature
    parameters["repetition_penalty"] = repetition_penalty

    return "Dialogue initialized with chosen settings\n\n" + "For debugging: "+ ",".join(character_sequence) + "\n\n" + show_dialogue()


def show_dialogue():
    """Render the user-facing scene: one history entry per line."""
    line_break = "\n"
    return line_break.join(display_history)


def choose_next_character():
    """Pop and return the next planned speaker, or "END" once none are left."""
    if not character_sequence:
        return "END"
    return character_sequence.pop(0)


def generate_next_line():
    """Generate the next speaker's line and return the updated scene text.

    When the planned speaker sequence is exhausted, the current scene is
    returned with an end-of-scene marker appended and nothing is generated.
    """
    slot = choose_next_character()
    if slot == "END":
        return show_dialogue() + "\n----End of scene----"

    speaker = character_dict[slot]
    gender = gender_dict[speaker]

    # Prompt = the most recent internal lines plus an opening tag for the
    # speaker whose line the model should write.
    recent_lines = internal_history[-parameters["context_length"]:]
    prompt = " ".join(recent_lines) + "<s>" + gender + ":"
    cutoff_text = generate_line(prompt)

    display_history.append(f"{speaker}: {cutoff_text}")
    internal_history.append(f"<s>{gender}: {cutoff_text}")
    return show_dialogue()


def cutoff_next_line(line):
    """Return the part of ``line`` that precedes the first speaker tag.

    A speaker tag is one of ``FEMALE``, ``MALE`` or ``UNKNOWN``. At least one
    character must precede the tag. The match spans newlines, so a multi-line
    utterance is kept whole instead of losing everything before the last
    newline.

    Args:
        line: raw text generated by the model.

    Returns:
        The text before the first speaker tag, or the sentinel ``"---"`` when
        no tag follows any text.
    """
    pattern = r".+?(?=FEMALE|MALE|UNKNOWN)"
    # re.DOTALL lets "." also consume newlines inside the utterance.
    match = re.search(pattern, line, flags=re.DOTALL)
    return match.group(0) if match else "---"


def generate_line(context, max_attempts=10):
    """Sample one dialogue line from the model for the given prompt.

    The model is asked for up to 50 new tokens using the sampling settings in
    ``parameters``; the result is cut at the next speaker tag. If the output
    contains no speaker tag, generation is retried. Retries are bounded so a
    model that keeps producing tag-less text cannot hang the UI or exhaust
    the recursion limit.

    Args:
        context: prompt text passed to the pipeline.
        max_attempts: maximum number of generations to try (at least one is
            always made).

    Returns:
        The text before the next speaker tag, or, if no attempt produced a
        tag, the complete text of the last attempt.
    """
    new_text = ""
    for _ in range(max(1, max_attempts)):
        new_text = pipe(
            context,
            return_full_text=False,
            max_new_tokens=50,
            do_sample=True,
            top_p=parameters["top_p"],
            top_k=parameters["top_k"],
            temperature=parameters["temperature"],
            repetition_penalty=parameters["repetition_penalty"],
        )[0]["generated_text"]
        cutoff_text = cutoff_next_line(new_text)
        if cutoff_text != "---":
            return cutoff_text
    # Best effort: no speaker tag within the token budget, keep the raw text.
    return new_text


def revise_last_line():
    """Regenerate the most recent line for the same speaker.

    If there is no line yet, the (empty) scene is returned unchanged instead
    of raising. The histories are only modified after the new line has been
    generated, so a failing generation leaves the existing line in place.

    Returns:
        The updated scene text.
    """
    if not internal_history or not display_history:
        return show_dialogue()

    # Internal entries look like "<s>GENDER: text"; display entries like
    # "Name: text". Recover the speaker's gender tag and name from them.
    gender = internal_history[-1].split("<s>")[1].split(": ", maxsplit=1)[0]
    name = display_history[-1].split(": ", maxsplit=1)[0]

    # The prompt must not contain the line that is being replaced.
    earlier_lines = internal_history[:-1]
    prompt = " ".join(earlier_lines[-parameters["context_length"]:]) + "<s>" + gender + ":"
    cutoff_text = generate_line(prompt)

    display_history[-1] = f"{name}: {cutoff_text}"
    internal_history[-1] = f"<s>{gender}: {cutoff_text}"
    return show_dialogue()


# Gradio front end: collects all settings, then wires the three buttons to
# dialogue_init / generate_next_line / revise_last_line.
with gr.Blocks() as demo:
    gr.Markdown("# Theater Scene Generation Prototype")
    gr.Markdown("## Initialization")
    gr.Markdown("### Text generation parameters:")
    top_p = gr.Slider(minimum=0.5, maximum=1, value=1, step=0.01, interactive=True, label="top_p: How many tokens to consider (in terms of probability). | Lower values = Discard low-probability tokens, 1 = Consider all tokens for sampling")
    top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, interactive=True, label="top_k: How many tokens to consider (in terms of number of tokens). | Values must be between 1 and 100")
    # Sampling requires a strictly positive temperature, so the slider starts
    # at the smallest step instead of 0.
    temperature = gr.Slider(minimum=0.05, maximum=2, value=1, step=0.05, interactive=True, label="temperature: Modulates the token distribution. | Low = Greedy generation, High = Probability mass is more uniformly distributed")
    repetition_penalty = gr.Slider(minimum=1, maximum=1.5, value=1, step=0.01, interactive=True, info="Very experimental. Recommended value is 1. If text generation times out, it is likely caused by this parameter",label="Repetition penalty | Low = No penalty, High = Avoid previously generated tokens")


    gr.Markdown("### Other parameters:")
    scene_length = gr.Slider(minimum=1.1, maximum=3, value=1.5, step=0.1, interactive=True, label="Scene length parameter: How long should the generated scene be by tendency? | 1 = longer; 3 = shorter")
    context_length = gr.Slider(minimum=1, maximum=10, value=3, step=1, interactive=True, label="Context length: How many of the last lines should be used as the prompt. | When using a CPU, 2-5 is recommended.")


    gr.Markdown("### Characters: ")
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Character A:")
            nameA = gr.Textbox(label="Name of character A")
            genderA = gr.Dropdown(["MALE", "FEMALE", "UNKNOWN"], value="UNKNOWN", label="Gender of character A")
            textA = gr.Textbox(label="First line of character A")
        with gr.Column():
            gr.Markdown("#### Character B:")
            nameB = gr.Textbox(label="Name of character B")
            genderB = gr.Dropdown(["MALE", "FEMALE", "UNKNOWN"], value="UNKNOWN", label="Gender of character B")
            textB = gr.Textbox(label="First line of character B")
        with gr.Column():
            gr.Markdown("#### Character C:")
            nameC = gr.Textbox(label="Name of character C")
            genderC = gr.Dropdown(["MALE", "FEMALE", "UNKNOWN"], value="UNKNOWN", label="Gender of character C")
            textC = gr.Textbox(label="First line of character C")

    gr.Examples(examples=[["Nathan", "MALE", "So macht nur, dass er euch hier nicht gewahr wird. Tretet mehr zurück. Geht lieber ganz hinein.",
                          "Recha", "FEMALE", "Nur einen Blick noch! — Ah! die Hecke, die mir ihn stiehlt.",
                          "Daja", "FEMALE", "Kommt! kommt! Der Vater hat ganz recht. Ihr lauft Gefahr, wenn er Euch sieht, dass auf der Stell’ er umkehrt."
                          ]], inputs=[nameA, genderA, textA, nameB, genderB, textB, nameC, genderC, textC] )


    gr.Markdown("### Character parameters:")
    gr.Markdown("**Important:**\n Centrality values have to add up to 1. (e.g., A: 0.6, B: 0.2, C: 0.2)\n Loyalty values for each character (within a column) have to add up to 1. (e.g., A->B: 0.1, A->C: 0.4, A->X: 0.5)")
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Parameters for character A:")
            # Default lowered from 0.8 so the three centrality defaults add up to 1.
            centr_A = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Centrality of A", value=0.6)
            loyalty_A_B = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of A to B", value=0.1)
            loyalty_A_C = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of A to C", value=0.2)
            loyalty_A_X = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Chance that A will end dialogue", value=0.7)
        with gr.Column():
            gr.Markdown("#### Parameters for character B:")
            centr_B = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Centrality of B", value=0.2)
            loyalty_B_A = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of B to A", value=0.8)
            loyalty_B_C = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of B to C", value=0.1)
            loyalty_B_X = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Chance that B will end dialogue", value=0.1)
        with gr.Column():
            gr.Markdown("#### Parameters for character C:")
            centr_C = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Centrality of C", value=0.2)
            loyalty_C_A = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of C to A", value=0.8)
            loyalty_C_B = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Loyalty of C to B", value=0.1)
            loyalty_C_X = gr.Slider(minimum=0, maximum=1, step=0.1, interactive=True, label="Chance that C will end dialogue", value=0.1)

    init_btn = gr.Button("Initialize play")

    gr.Markdown("## Generated scene")
    output = gr.Textbox(label="Scene")
    with gr.Row():
        with gr.Column():
            generate_btn = gr.Button("Generate next line")
        with gr.Column():
            revise_btn = gr.Button("Revise last line")


    ### Functionalities

    # Initialization: the order of `inputs` must match the positional
    # parameters of dialogue_init.
    init_btn.click(fn=dialogue_init, inputs=[top_p, top_k, temperature, repetition_penalty, scene_length, context_length,nameA, nameB, nameC, genderA, genderB, genderC, textA, textB, textC, centr_A, centr_B, centr_C, loyalty_A_B, loyalty_A_C, loyalty_A_X, loyalty_B_A, loyalty_B_C, loyalty_B_X, loyalty_C_A, loyalty_C_B, loyalty_C_X], outputs=output)

    # Generate next line
    generate_btn.click(fn=generate_next_line, inputs=None, outputs=output)

    # Revise last line
    revise_btn.click(fn=revise_last_line, inputs=None, outputs=output)


demo.launch()


Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.


