### Multi Turn Program Synthesis

- Use the dropdowns to select the model, device, and precision, then click the button to load the model
    - This can take some time depending on the model size

In [None]:
import ipywidgets as widgets
from ipywidgets import Button, HBox, VBox
from IPython.display import display, clear_output

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


In [None]:
def load_model(model_name, device, precision):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        device: Torch device string the model is moved to (e.g. ``"cpu"``
            or ``"cuda:0"``).
        precision: ``"float16"`` loads the weights in half precision; any
            other value loads them in ``torch.float32``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        AssertionError: If ``"float16"`` is requested for a device whose
            name does not contain ``"cuda"``.
    """
    # Validate first so an invalid combination fails before anything is loaded.
    assert precision != "float16" or "cuda" in device, (
        "'float16' is only supported when using the GPU")

    _dtype = torch.float16 if precision == "float16" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=_dtype).to(device)

    # 50256 is a token *id*; assign it through ``pad_token_id`` rather than
    # ``pad_token``, which expects the token string.
    tokenizer.pad_token_id = 50256
    # Pad on the left so that, for batched prompts, every continuation starts
    # right after the (padded) input width that ``generate`` slices at.
    tokenizer.padding_side = "left"

    return tokenizer, model

In [None]:
def truncate(completion):
    """Cut a generated completion down to a single code snippet.

    The text is first clipped just before the second line that starts with
    ``print`` and then just before the second line that starts with ``def``.
    After that it is cut at the earliest of: a line starting with ``#``, the
    literal ``<|endoftext|>``, a line starting with a triple quote, or three
    consecutive newlines. If none of these occur, the (clipped) text is
    returned unchanged.
    """
    import re

    # Keep at most one top-level ``print`` and one top-level ``def``.
    for line_start in ('^print', '^def'):
        hits = list(re.finditer(line_start, completion, re.MULTILINE))
        if len(hits) > 1:
            completion = completion[:hits[1].start()]

    stop_patterns = ['^#', re.escape('<|endoftext|>'), "^'''", '^"""', '\n\n\n']

    # Collect the start offset of every stop pattern that occurs.
    cut_points = []
    for raw in stop_patterns:
        hit = re.compile(raw, re.MULTILINE).search(completion, 0)
        if hit:
            cut_points.append(hit.start())

    if not cut_points:
        return completion
    return completion[:min(cut_points)]

In [None]:
def generate(prompt, tokenizer, model, device):
    """Sample one continuation per prompt and return the decoded new text.

    Args:
        prompt: A prompt string or a list of prompt strings; passed straight
            to ``tokenizer``.
        tokenizer: Tokenizer used to encode the prompt(s) and decode the
            generated ids.
        model: Model exposing ``generate``.
        device: Torch device the encoded inputs are moved to.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated ids with
        the input positions stripped off (special tokens are not skipped).
    """
    encoded = tokenizer(prompt, truncation=True, padding=True, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    # Pass the mask explicitly so padded positions are ignored during generation.
    attention_mask = encoded["attention_mask"].to(device)
    input_ids_len = input_ids.shape[1]

    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        num_return_sequences=1,
        temperature=0.2,
        max_new_tokens=256,
        top_p=0.95,
        pad_token_id=50256,
        use_cache=True,
    )

    # Everything past the input width is newly generated.
    decoded_generation = tokenizer.batch_decode(generated_ids[:, input_ids_len:])

    return decoded_generation

In [None]:
# Shared notebook state; filled in by the "Load Model" button callback.
tokenizer, model, device, precision = None, None, None, None

# Only default to GPU + half precision when CUDA is actually usable, so the
# default selection also works on CPU-only machines.
_cuda_available = torch.cuda.is_available()

model_chooser = widgets.Dropdown(
    options=["Salesforce/codegen-350M-mono", "Salesforce/codegen-2B-mono", "Salesforce/codegen-6B-mono", "Salesforce/codegen-16B-mono"],
    value="Salesforce/codegen-350M-mono",
    description='Model:',
    disabled=False
)

precision_chooser = widgets.Dropdown(
    options=["float32", "float16"],
    value="float16" if _cuda_available else "float32",
    description='Precision:',
    disabled=False
)

device_chooser = widgets.Dropdown(
    options=["cpu", "cuda:0"],
    value="cuda:0" if _cuda_available else "cpu",
    description='Device:',
    disabled=False
)

model_button = widgets.Button(
    description="Load Model",
    disabled=False,
    tooltip="Loads the model, selected in the dropdown"
)

# Receives the progress messages printed while a model is loading.
chooser_output = widgets.Output()

def on_model_button_clicked(b):
    """Load the model selected in the dropdowns into the shared globals.

    Args:
        b: The clicked button (unused).

    The globals ``tokenizer``, ``model``, ``device`` and ``precision`` are
    only replaced after ``load_model`` has returned, so if loading raises
    they keep their previous, mutually consistent values.
    """
    global tokenizer, model, device, precision

    model_name = model_chooser.value
    selected_device = device_chooser.value
    selected_precision = precision_chooser.value

    with chooser_output:
        clear_output()
        print(f"Loading model {model_name}")
        new_tokenizer, new_model = load_model(model_name, selected_device, selected_precision)
        # Commit everything together, and only on success.
        tokenizer, model = new_tokenizer, new_model
        device, precision = selected_device, selected_precision
        print("Loading finished")


# Hook the loader callback up to the "Load Model" button.
model_button.on_click(on_model_button_clicked)

# One row of controls, with the loader's output area stacked underneath.
chooser_layout = HBox(
    children=[model_chooser, device_chooser, precision_chooser, model_button],
)
chooser_layout_main = VBox(children=[chooser_layout, chooser_output])

display(chooser_layout_main)


In [None]:
text_area = widgets.Textarea(
    value='Hello World',
    placeholder='Type something',
    description='String:',
    disabled=False
)

button = widgets.Button(
    description='Click me',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me'
)

# Receives the progress messages and the generated program text.
output = widgets.Output()

layout = HBox([text_area, button])
layout_main = VBox([layout, output])


def _wrap(prompt):
    """Render ``prompt`` as Python comment lines, each ending in a newline.

    Every line of a multi-line prompt gets its own ``# `` prefix so none of
    the instruction text ends up in the history as bare code. A single-line
    prompt yields ``"# <prompt>\\n"``.
    """
    lines = prompt.splitlines() or [""]
    return "".join(f"# {line}\n" for line in lines)


# One running program text per history; each turn appends the wrapped prompt
# followed by the truncated completion.
histories = ["# Import libraries.\n\nimport numpy as np\n\n"]
prompt_count = 0

def on_button_clicked(b):
    """Run one synthesis turn: append the prompt, generate, and show the result.

    Args:
        b: The clicked button (unused).

    The text-area value is wrapped with ``_wrap`` and appended to every
    history, ``generate`` is called on the extended histories, and each
    truncated completion is appended in turn. ``histories`` and
    ``prompt_count`` are only updated once generation has succeeded, so a
    failed turn does not leave a dangling prompt in the history.
    """
    global histories, prompt_count

    # Do all the work inside the output context so progress messages and
    # results end up in the same widget.
    with output:
        if tokenizer is None or model is None:
            print("No model loaded - use the 'Load Model' button first.")
            return

        print("Generating...")

        current_input = text_area.value

        # Build the new prompts in a local first; commit after generation.
        prompts = [h + _wrap(current_input) for h in histories]
        completions = generate(prompts, tokenizer, model, device)
        histories = [p + f"{truncate(c)}\n\n" for p, c in zip(prompts, completions)]

        print("-" * 10)
        print(prompt_count)
        print("-" * 10)
        print(histories[0])
        print("-" * 10)

        prompt_count += 1


# Register the per-turn callback on the button.
button.on_click(on_button_clicked)

# Render the prompt area, the button and the output area in the notebook.
display(layout_main)