In [1]:
%load_ext autoreload
%autoreload 2


In [9]:
import json
import random
import threading
from pathlib import Path

import torch
from datasets import load_from_disk
from transformers import (
    RobertaTokenizer,
    T5ForConditionalGeneration,
)

from draw_svg import *

In [3]:
# Number of CPU threads torch may use for intra-op parallelism.
TORCH_NUM_THREADS = 16
torch.set_num_threads(TORCH_NUM_THREADS)

In [3]:
# Tokenizer comes from the public CodeT5-small release; the seq2seq weights are
# loaded from a local checkpoint directory.
TOKENIZER_NAME = "Salesforce/codet5-small"
MODEL_DIR = "/data/nicolasmaier/model/ended-3"

tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)


In [4]:
# Prefer the first CUDA device, falling back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# nn.Module.to moves parameters in place and returns the same module, so
# `model_gpu` and `model` refer to one object.
model_gpu = model.to(device)

cuda:0


In [5]:
# Pre-tokenized DatasetDict saved with `save_to_disk`; printing shows splits and columns.
DATASET_DIR = "/data/nicolasmaier/dataset/hf_clean_seq_dataset_3"

dataset = load_from_disk(DATASET_DIR)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'seq', 'labels'],
        num_rows: 366247
    })
    valid: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'seq', 'labels'],
        num_rows: 13022
    })
    test: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'seq', 'labels'],
        num_rows: 21563
    })
})


In [6]:
def fix_json(s):
    """Best-effort repair of unbalanced brackets in model-generated JSON text.

    Scans ``s`` once, tracking whether the cursor is inside a string literal
    and keeping a stack of the ``[`` / ``{`` opened outside strings:

    * a ``]`` / ``}`` outside a string that does not match the innermost open
      bracket (or has nothing to close) is replaced by a space, so the length
      and every character position of ``s`` are preserved;
    * brackets inside string literals are left untouched;
    * at the end, an unterminated string literal is closed (a dangling
      backslash is dropped first so it cannot escape the added quote), then
      every bracket still open is closed, innermost first.

    A warning is printed for every repair.

    Args:
        s: JSON text, possibly truncated or containing stray closers.

    Returns:
        The repaired text. It is not guaranteed to be valid JSON (for example,
        a truncated object key is not completed).
    """
    closer_to_opener = {"]": "[", "}": "{"}
    opener_to_closer = {"[": "]", "{": "}"}

    chars = list(s)
    stack = []
    in_string = False
    escaped = False  # only meaningful while in_string

    for i, c in enumerate(chars):
        if in_string:
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
            continue

        if c == '"':
            in_string = True
        elif c in opener_to_closer:
            stack.append(c)
        elif c in closer_to_opener:
            if stack and stack[-1] == closer_to_opener[c]:
                stack.pop()
            else:
                print(f"warning: {c} without {closer_to_opener[c]}")
                # Same-length replacement keeps positions stable for callers
                # that map json error offsets back onto the original text.
                chars[i] = " "

    if in_string:
        print("warning: unterminated string")
        if escaped:
            chars.pop()  # dangling backslash would escape the closing quote
        chars.append('"')

    for c in reversed(stack):
        print(f"warning: {c} without {opener_to_closer[c]}")
        chars.append(opener_to_closer[c])

    return "".join(chars)

def model_generate(code, num_beams=10, max_length=510, verbose=True):
    """Run the seq2seq model on ``code`` and parse its output as JSON.

    The input is tokenized, decoded with (deterministic) beam search, repaired
    with ``fix_json`` and parsed with ``json.loads``. If the parser reports
    trailing "Extra data" after a complete JSON value, the text is cut at that
    position and parsing is retried.

    Args:
        code: source text to translate (e.g. a Java method body).
        num_beams: beam width passed to ``generate``.
        max_length: maximum generated length in tokens passed to ``generate``.
        verbose: print the input and the raw decoded output.

    Returns:
        Whatever ``json.loads`` produces for the repaired output (including
        ``None`` for a JSON ``null``).

    Raises:
        json.JSONDecodeError: if the repaired output is not valid JSON for any
            reason other than trailing extra data.
    """
    model_input = tokenizer(code, return_tensors="pt").to(device)

    # temperature/top_k/top_p only take effect with do_sample=True, so plain
    # beam search does not pass them.
    outputs = model_gpu.generate(
        **model_input,  # input_ids together with attention_mask
        num_beams=num_beams,
        max_length=max_length,
    )
    # Dropping special tokens avoids assuming their positions; a fixed slice
    # would cut a real token when generation stops at max_length without </s>.
    model_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if verbose:
        print("input:", code)
        print("output:", model_output)

    # Returning directly (instead of looping until a result is not None) lets a
    # JSON `null` output terminate; each retry strictly shortens model_output.
    while True:
        try:
            return json.loads(fix_json(model_output))
        except json.JSONDecodeError as e:
            if e.msg != "Extra data":
                print("error:")
                print(model_output)
                raise
            print("warning: extra data in", model_output)
            # fix_json only appends after the original text, so e.pos indexes
            # model_output as well.
            model_output = model_output[: e.pos]

In [14]:
with open("test_code.txt", "r", encoding="utf-8") as code_file:
    code = code_file.read()

# Run the model on the source and parse its JSON output.
seq_text = model_generate(code)

# draw_svg runs on a worker thread that is joined immediately.
# NOTE(review): presumably because draw_svg cannot run on the notebook's main
# thread (e.g. it refuses to run inside Jupyter's event loop) — confirm.
t_res = None
t_err = None

def t_run():
    """Render `seq_text` with draw_svg, storing the result or the raised exception."""
    global t_res, t_err
    try:
        t_res = draw_svg(seq_text)
    except BaseException as exc:  # re-raised on the main thread below
        t_err = exc

t = threading.Thread(target=t_run)
t.start()
t.join()
# Surface worker failures here; otherwise they would only appear as a
# confusing TypeError from writing None below.
if t_err is not None:
    raise t_err

out_path = Path("out/model.svg")
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "wb") as svg_file:
    svg_file.write(t_res)

input: public void test(int x, String y) {
		A a = new A();
		
		a.foo1();
		a.foo2(42);
		
		while (a.foo3()) {
			B b = new B();
			
			b.bar1();
			
			if (b.bar2() && xyz) {
				b.bar3(42);
			} else {
				b.bar4();
			}
		}
	}
output: {"title": "test(x, y)", "sequence": [{"type": "newInstance", "new_type": "A"}, {"type": "scopedVariable", "name": "a"}, {"type": "methodInvocation", "to": ["a"], "method": "foo1()"}, {"type": "methodInvocation", "to": ["a"], "method": "foo2(42)"}, {"type": "methodInvocation", "to": ["a"], "method": "foo3()"}, {"type": "blocks", "name": "while", "blocks": [{"guard": "a.foo3()", "contents": [{"type": "newInstance", "new_type": "B"}, {"type": "scopedVariable", "name": "b"}, {"type": "methodInvocation", "to": ["b"], "method": "bar1()"}, {"type": "methodInvocation", "to": ["b"], "method": "bar2()"}, {"type": "blocks", "name": "if", "blocks": [{"guard": "b.bar2() && xyz", "contents": [{"type": "methodInvocation", "to": ["b"], "method": "bar3(42)"}]}, {"gua